diff --git a/Cargo-minimal.lock b/Cargo-minimal.lock index a958e7350..04c6819eb 100644 --- a/Cargo-minimal.lock +++ b/Cargo-minimal.lock @@ -1413,9 +1413,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.33" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" dependencies = [ "crc32fast", "miniz_oxide", @@ -2090,6 +2090,12 @@ version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "187674a687eed5fe42285b40c6291f9a01517d415fad1c3cbc6a9f778af7fcd4" +[[package]] +name = "ipnetwork" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf370abdafd54d13e54a620e8c3e1145f28e46cc9d704bc6d94414559df41763" + [[package]] name = "iri-string" version = "0.7.8" @@ -2280,6 +2286,19 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" +[[package]] +name = "maxminddb" +version = "0.27.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76371bd37ce742f8954daabd0fde7f1594ee43ac2200e20c003ba5c3d65e2192" +dependencies = [ + "ipnetwork", + "log", + "memchr", + "serde", + "thiserror 2.0.17", +] + [[package]] name = "memchr" version = "2.7.4" @@ -2311,11 +2330,12 @@ dependencies = [ [[package]] name = "miniz_oxide" -version = "0.8.0" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" dependencies = [ "adler2", + "simd-adler32", ] [[package]] @@ -2777,6 +2797,9 @@ dependencies = [ "axum-server", "clap", "config", + "flate2", + "ipnet", + "maxminddb", "ohttp-relay", "opentelemetry", "opentelemetry-otlp", @@ -3822,6 +3845,12 @@ dependencies = [ "libc", ] +[[package]] +name = "simd-adler32" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" + [[package]] name = "similar" version = "2.7.0" diff --git a/Cargo-recent.lock b/Cargo-recent.lock index a958e7350..04c6819eb 100644 --- a/Cargo-recent.lock +++ b/Cargo-recent.lock @@ -1413,9 +1413,9 @@ dependencies = [ [[package]] name = "flate2" -version = "1.0.33" +version = "1.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "324a1be68054ef05ad64b861cc9eaf1d623d2d8cb25b4bf2cb9cdd902b4bf253" +checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" dependencies = [ "crc32fast", "miniz_oxide", @@ -2090,6 +2090,12 @@ version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "187674a687eed5fe42285b40c6291f9a01517d415fad1c3cbc6a9f778af7fcd4" +[[package]] +name = "ipnetwork" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf370abdafd54d13e54a620e8c3e1145f28e46cc9d704bc6d94414559df41763" + [[package]] name = "iri-string" version = "0.7.8" @@ -2280,6 +2286,19 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" +[[package]] +name = "maxminddb" +version = "0.27.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76371bd37ce742f8954daabd0fde7f1594ee43ac2200e20c003ba5c3d65e2192" +dependencies = [ + "ipnetwork", + "log", + "memchr", + "serde", + "thiserror 2.0.17", +] + [[package]] name = "memchr" version = "2.7.4" @@ -2311,11 +2330,12 @@ dependencies = [ [[package]] name = "miniz_oxide" -version = "0.8.0" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +checksum = "1fa76a2c86f704bdb222d66965fb3d63269ce38518b83cb0575fca855ebb6316" dependencies = [ "adler2", + "simd-adler32", ] [[package]] @@ -2777,6 +2797,9 @@ dependencies = [ "axum-server", "clap", "config", + "flate2", + "ipnet", + "maxminddb", "ohttp-relay", "opentelemetry", "opentelemetry-otlp", @@ -3822,6 +3845,12 @@ dependencies = [ "libc", ] +[[package]] +name = "simd-adler32" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" + [[package]] name = "similar" version = "2.7.0" diff --git a/flake.nix b/flake.nix index 1957ae532..98d804ce3 100644 --- a/flake.nix +++ b/flake.nix @@ -85,6 +85,7 @@ filter = path: type: (builtins.match ".*nginx.conf.template$" path != null) + || (builtins.match ".*\\.mmdb$" path != null) || (craneLibVersions.msrv.filterCargoSources path type); name = "source"; }; @@ -162,7 +163,7 @@ "payjoin-cli" = "--features v1,v2"; "payjoin-directory" = ""; "ohttp-relay" = ""; - "payjoin-mailroom" = "--features acme,telemetry"; + "payjoin-mailroom" = "--features access-control,acme,telemetry"; }; # nix2container for building OCI/Docker images diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index d295e7e5f..1058e0cab 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -65,12 +65,72 @@ fn init_tls_acceptor(cert_key: (Vec, Vec)) -> Result>>, +); + +impl BlockedAddresses { + pub fn empty() -> Self { + Self(Arc::new(tokio::sync::RwLock::new(std::collections::HashSet::new()))) + } + + pub fn from_address_lines(text: &str) -> Self { + Self(Arc::new(tokio::sync::RwLock::new(parse_address_lines(text)))) + } + + /// Replace the contents with scripts parsed from newline-delimited + /// address text. Returns the number of entries after update. + pub async fn update_from_lines(&self, text: &str) -> usize { + let scripts = parse_address_lines(text); + let count = scripts.len(); + *self.0.write().await = scripts; + count + } +} + +/// V1 protocol configuration. +/// +/// Its presence in [`Service`] enables the V1 fallback path; +/// its contents carry optional blocklist screening. +#[derive(Clone, Default)] +pub struct V1 { + blocked_addresses: Option, +} + +impl V1 { + pub fn new(blocked_addresses: Option) -> Self { Self { blocked_addresses } } +} + +fn parse_address_lines(text: &str) -> std::collections::HashSet { + text.lines() + .filter_map(|l| { + let trimmed = l.trim(); + if trimmed.is_empty() { + return None; + } + match trimmed.parse::>() { + Ok(addr) => Some(addr.assume_checked().script_pubkey()), + Err(e) => { + tracing::warn!("Skipping unparsable blocked address {trimmed:?}: {e}"); + None + } + } + }) + .collect() +} + #[derive(Clone)] pub struct Service { db: D, ohttp: ohttp::Server, sentinel_tag: SentinelTag, - enable_v1: bool, + v1: Option, } impl tower::Service> for Service @@ -94,8 +154,8 @@ where } impl Service { - pub fn new(db: D, ohttp: ohttp::Server, sentinel_tag: SentinelTag, enable_v1: bool) -> Self { - Self { db, ohttp, sentinel_tag, enable_v1 } + pub fn new(db: D, ohttp: ohttp::Server, sentinel_tag: SentinelTag, v1: Option) -> Self { + Self { db, ohttp, sentinel_tag, v1 } } #[cfg(feature = "_manual-tls")] @@ -241,7 +301,7 @@ impl Service { B: Body + Send + 'static, B::Error: Into, { - if self.enable_v1 { + if self.v1.is_some() { self.post_fallback_v1(id, query, body).await } else { let _ = (id, query, body); @@ -329,7 +389,7 @@ impl Service { match (parts.method, path_segments.as_slice()) { (Method::POST, &["", id]) => self.post_mailbox(id, body).await, (Method::GET, &["", id]) => self.get_mailbox(id).await, - (Method::PUT, &["", id]) if self.enable_v1 => self.put_payjoin_v1(id, body).await, + (Method::PUT, &["", id]) if self.v1.is_some() => self.put_payjoin_v1(id, body).await, _ => Ok(not_found()), } } @@ -368,6 +428,30 @@ impl Service { let timeout_response = Response::builder().status(StatusCode::ACCEPTED).body(empty())?; handle_peek(self.db.wait_for_v2_payload(&id).await, timeout_response) } + + /// Screen a V1 PSBT body against the address blocklist. + /// + /// Returns `Ok(())` if screening passes or is not configured. + async fn check_v1_blocklist(&self, body_str: &str) -> Result<(), HandlerError> { + if let Some(blocked) = self.v1.as_ref().and_then(|v| v.blocked_addresses.as_ref()) { + let scripts = blocked.0.read().await; + if !scripts.is_empty() { + match screen_v1_addresses(body_str, &scripts) { + ScreenResult::Blocked => { + return Err(HandlerError::V1PsbtRejected(anyhow::anyhow!( + "blocked address in V1 PSBT" + ))); + } + ScreenResult::Clean => {} + ScreenResult::ParseError(e) => { + warn!("Could not parse V1 PSBT: {e}"); + } + } + } + } + Ok(()) + } + async fn put_payjoin_v1( &self, id: &str, @@ -386,6 +470,9 @@ impl Service { return Err(HandlerError::PayloadTooLarge); } + let body_str = std::str::from_utf8(&req).map_err(|e| HandlerError::BadRequest(e.into()))?; + self.check_v1_blocklist(body_str).await?; + match self.db.post_v1_response(&id, req.into()).await { Ok(_) => Ok(ok_response), Err(e) => Err(HandlerError::BadRequest(e.into())), @@ -419,6 +506,8 @@ impl Service { Err(_) => return Ok(bad_request_body_res), }; + self.check_v1_blocklist(&body_str).await?; + let v2_compat_body = format!("{body_str}\n{query}"); let id = ShortId::from_str(id)?; handle_peek( @@ -562,6 +651,8 @@ enum HandlerError { SenderGone(anyhow::Error), OhttpKeyRejection(anyhow::Error), BadRequest(anyhow::Error), + /// V1 PSBT rejected — returns the BIP78 `original-psbt-rejected` error. + V1PsbtRejected(anyhow::Error), Forbidden(anyhow::Error), } @@ -595,6 +686,11 @@ impl HandlerError { warn!("Bad request: {}", e); *res.status_mut() = StatusCode::BAD_REQUEST } + HandlerError::V1PsbtRejected(e) => { + warn!("PSBT rejected: {}", e); + *res.status_mut() = StatusCode::BAD_REQUEST; + *res.body_mut() = full(V1_REJECT_RES_JSON); + } HandlerError::Forbidden(e) => { warn!("Forbidden: {}", e); *res.status_mut() = StatusCode::FORBIDDEN @@ -629,6 +725,57 @@ fn full>(chunk: T) -> BoxBody { Full::new(chunk.into()).map_err(|never| match never {}).boxed() } +enum ScreenResult { + Blocked, + Clean, + ParseError(String), +} + +fn screen_v1_addresses( + body: &str, + blocked: &std::collections::HashSet, +) -> ScreenResult { + use bitcoin::base64::prelude::{Engine, BASE64_STANDARD}; + use bitcoin::psbt::Psbt; + + let psbt_bytes = match BASE64_STANDARD.decode(body) { + Ok(b) => b, + Err(e) => return ScreenResult::ParseError(format!("base64 decode: {e}")), + }; + + let psbt = match Psbt::deserialize(&psbt_bytes) { + Ok(p) => p, + Err(e) => return ScreenResult::ParseError(format!("PSBT deserialize: {e}")), + }; + + // Check output scripts + for txout in &psbt.unsigned_tx.output { + if blocked.contains(&txout.script_pubkey) { + return ScreenResult::Blocked; + } + } + + // Check input scripts from witness_utxo and non_witness_utxo + for (i, input) in psbt.inputs.iter().enumerate() { + if let Some(ref utxo) = input.witness_utxo { + if blocked.contains(&utxo.script_pubkey) { + return ScreenResult::Blocked; + } + } + if let Some(ref tx) = input.non_witness_utxo { + if let Some(prev_out) = psbt.unsigned_tx.input.get(i) { + if let Some(txout) = tx.output.get(prev_out.previous_output.vout as usize) { + if blocked.contains(&txout.script_pubkey) { + return ScreenResult::Blocked; + } + } + } + } + } + + ScreenResult::Clean +} + #[cfg(test)] mod tests { use std::time::Duration; @@ -641,12 +788,12 @@ mod tests { use super::*; - async fn test_service(enable_v1: bool) -> Service { + async fn test_service(v1: Option) -> Service { let dir = tempfile::tempdir().expect("tempdir"); let db = FilesDb::init(Duration::from_millis(100), dir.keep()).await.expect("db init"); let ohttp: ohttp::Server = key_config::gen_ohttp_server_config().expect("ohttp config").into(); - Service::new(db, ohttp, SentinelTag::new([0u8; 32]), enable_v1) + Service::new(db, ohttp, SentinelTag::new([0u8; 32]), v1) } /// A valid ShortId encoded as bech32 for use in URL paths. @@ -661,9 +808,11 @@ mod tests { (parts.status, String::from_utf8(bytes.to_vec()).unwrap()) } + // V1 routing + #[tokio::test] async fn post_v1_when_disabled_returns_version_unsupported() { - let mut svc = test_service(false).await; + let mut svc = test_service(None).await; let id = valid_short_id_path(); let req = Request::builder() .method(Method::POST) @@ -680,7 +829,7 @@ mod tests { #[tokio::test] async fn post_v1_with_invalid_body_returns_reject() { - let mut svc = test_service(true).await; + let mut svc = test_service(Some(V1::new(None))).await; let id = valid_short_id_path(); let req = Request::builder() .method(Method::POST) @@ -697,7 +846,7 @@ mod tests { #[tokio::test] async fn post_v1_with_no_receiver_returns_unavailable() { - let mut svc = test_service(true).await; + let mut svc = test_service(Some(V1::new(None))).await; let id = valid_short_id_path(); let req = Request::builder() .method(Method::POST) @@ -711,4 +860,107 @@ mod tests { assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE); assert_eq!(body, V1_UNAVAILABLE_RES_JSON); } + + // Address screening + + fn make_test_psbt_base64(output_address: &str) -> String { + use bitcoin::base64::prelude::{Engine, BASE64_STANDARD}; + use bitcoin::psbt::Psbt; + use bitcoin::{Amount, Transaction, TxIn, TxOut}; + + let addr: bitcoin::Address = + output_address.parse().expect("valid address"); + let script_pubkey = addr.assume_checked().script_pubkey(); + + let tx = Transaction { + version: bitcoin::transaction::Version::TWO, + lock_time: bitcoin::blockdata::locktime::absolute::LockTime::ZERO, + input: vec![TxIn::default()], + output: vec![TxOut { value: Amount::from_sat(50_000), script_pubkey }], + }; + + let psbt = Psbt::from_unsigned_tx(tx).expect("valid psbt"); + BASE64_STANDARD.encode(psbt.serialize()) + } + + fn addr_to_script(address: &str) -> bitcoin::ScriptBuf { + let addr: bitcoin::Address = + address.parse().expect("valid address"); + addr.assume_checked().script_pubkey() + } + + #[tokio::test] + async fn post_v1_with_blocked_address_returns_bad_request() { + let blocked_addr = "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa"; + let blocked = BlockedAddresses::from_address_lines(blocked_addr); + let mut svc = test_service(Some(V1::new(Some(blocked)))).await; + let id = valid_short_id_path(); + let psbt_b64 = make_test_psbt_base64(blocked_addr); + let req = Request::builder() + .method(Method::POST) + .uri(format!("http://localhost/{id}")) + .body(Full::new(Bytes::from(psbt_b64))) + .unwrap(); + + let res = tower::Service::call(&mut svc, req).await.unwrap(); + let (status, body) = collect_body(res).await; + + assert_eq!(status, StatusCode::BAD_REQUEST); + assert_eq!(body, V1_REJECT_RES_JSON); + } + + #[test] + fn screen_blocks_blocked_output_address() { + let blocked_addr = "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa"; + let blocked = std::collections::HashSet::from([addr_to_script(blocked_addr)]); + + let psbt_b64 = make_test_psbt_base64(blocked_addr); + assert!(matches!(screen_v1_addresses(&psbt_b64, &blocked), ScreenResult::Blocked)); + } + + #[test] + fn screen_allows_clean_psbt() { + let clean_addr = "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa"; + let blocked = std::collections::HashSet::new(); // empty + let psbt_b64 = make_test_psbt_base64(clean_addr); + assert!(matches!(screen_v1_addresses(&psbt_b64, &blocked), ScreenResult::Clean)); + } + + #[test] + fn screen_allows_non_blocked_address() { + let addr = "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa"; + let blocked = + std::collections::HashSet::from([addr_to_script("3J98t1WpEZ73CNmQviecrnyiWrnqRhWNLy")]); + + let psbt_b64 = make_test_psbt_base64(addr); + assert!(matches!(screen_v1_addresses(&psbt_b64, &blocked), ScreenResult::Clean)); + } + + #[test] + fn screen_parse_error_on_invalid_base64() { + let blocked = + std::collections::HashSet::from([addr_to_script("1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa")]); + assert!(matches!( + screen_v1_addresses("not-valid-base64!!!", &blocked), + ScreenResult::ParseError(_) + )); + } + + #[test] + fn screen_parse_error_on_invalid_psbt() { + use bitcoin::base64::prelude::{Engine, BASE64_STANDARD}; + let blocked = + std::collections::HashSet::from([addr_to_script("1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa")]); + let bad_psbt = BASE64_STANDARD.encode(b"not a psbt"); + assert!(matches!(screen_v1_addresses(&bad_psbt, &blocked), ScreenResult::ParseError(_))); + } + + #[test] + fn screen_blocks_bech32_address() { + let addr = "bc1qxy2kgdygjrsqtzq2n0yrf2493p83kkfjhx0wlh"; + let blocked = std::collections::HashSet::from([addr_to_script(addr)]); + + let psbt_b64 = make_test_psbt_base64(addr); + assert!(matches!(screen_v1_addresses(&psbt_b64, &blocked), ScreenResult::Blocked)); + } } diff --git a/payjoin-directory/src/main.rs b/payjoin-directory/src/main.rs index e60483f3d..2d7ae3554 100644 --- a/payjoin-directory/src/main.rs +++ b/payjoin-directory/src/main.rs @@ -29,7 +29,8 @@ async fn main() -> Result<(), BoxError> { .await .expect("Failed to initialize persistent storage"); - let service = Service::new(db, ohttp.into(), SentinelTag::new([0u8; 32]), config.enable_v1); + let v1 = if config.enable_v1 { Some(V1::new(None)) } else { None }; + let service = Service::new(db, ohttp.into(), SentinelTag::new([0u8; 32]), v1); let listener = TcpListener::bind(config.listen_addr).await?; diff --git a/payjoin-mailroom/Cargo.toml b/payjoin-mailroom/Cargo.toml index 5ff4c54e0..4b5ba8448 100644 --- a/payjoin-mailroom/Cargo.toml +++ b/payjoin-mailroom/Cargo.toml @@ -22,6 +22,7 @@ acme = [ "dep:rustls", "dep:tokio-stream", ] +access-control = ["dep:flate2", "dep:ipnet", "dep:maxminddb", "dep:reqwest"] telemetry = ["dep:opentelemetry-otlp"] [dependencies] @@ -32,6 +33,9 @@ axum-server = { version = "0.8", features = [ ], optional = true } clap = { version = "4.5", features = ["derive", "env"] } config = "0.15" +flate2 = { version = "1.1", optional = true } +ipnet = { version = "2", optional = true } +maxminddb = { version = "0.27", optional = true } ohttp-relay = { path = "../ohttp-relay", features = ["bootstrap"] } opentelemetry = "0.31" opentelemetry-otlp = { version = "0.31", optional = true, features = [ @@ -40,6 +44,9 @@ opentelemetry-otlp = { version = "0.31", optional = true, features = [ opentelemetry_sdk = "0.31" payjoin-directory = { path = "../payjoin-directory" } rand = "0.8" +reqwest = { version = "0.12", default-features = false, features = [ + "rustls-tls", +], optional = true } rustls = { version = "0.23", default-features = false, features = [ "ring", ], optional = true } diff --git a/payjoin-mailroom/README.md b/payjoin-mailroom/README.md index 11915626a..19fc60b7a 100644 --- a/payjoin-mailroom/README.md +++ b/payjoin-mailroom/README.md @@ -34,21 +34,21 @@ nix run .#payjoin-mailroom -- --config payjoin-mailroom/config.toml ## Telemetry -payjoin-mailroom supports **optional** OpenTelemetry-based telemetry (metrics, traces, and logs). Build with `--features telemetry` and add a `[telemetry]` section to your config: +payjoin-mailroom supports **optional** OpenTelemetry-based telemetry (metrics). +Build with `--features telemetry` and configure via the [`[telemetry]`](config.example.com) config section. +When no telemetry configuration is present, it falls back to local-only console tracing. -```toml -[telemetry] -endpoint = "https://otlp-gateway-prod-us-west-0.grafana.net/otlp" -auth_token = "" -operator_domain = "your-domain.example.com" -``` +## Access Control -Or set the equivalent environment variables: +Build with `--features access-control` to enable: -```sh -export PJ_TELEMETRY__ENDPOINT="https://otlp-gateway-prod-us-west-0.grafana.net/otlp" -export PJ_TELEMETRY__AUTH_TOKEN="" -export PJ_TELEMETRY__OPERATOR_DOMAIN="your-domain.example.com" -``` +### IP Screening + +Configured via the [`[access_control]`](config.example.toml) config section for IP- and region-based filtering. + +The auto-fetched GeoLite2 database is provided by [MaxMind](https://www.maxmind.com) and distributed under the [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0/) license. + +### V1 Address Screening -When no `[telemetry]` section is present and no `PJ_TELEMETRY__*` variables are set, it falls back to local-only console tracing. +When the V1 protocol is enabled, payjoin-mailroom can screen PSBTs for blocked Bitcoin addresses. +Configure a local blocklist, a remote URL, or both via the [`[v1]`](config.example.toml) config section. diff --git a/payjoin-mailroom/config.example.toml b/payjoin-mailroom/config.example.toml new file mode 100644 index 000000000..6a9e995fe --- /dev/null +++ b/payjoin-mailroom/config.example.toml @@ -0,0 +1,55 @@ +# Payjoin Mailroom configuration example +# +# Configuration can also be set via environment variables with the `PJ_` +# prefix. Nested values use double underscores as separators, e.g. +# PJ_TELEMETRY__OPERATOR_DOMAIN="your-domain.example.com" + +# Address and port to listen on +# listener = "[::]:8080" + +# Directory for persistent storage (OHTTP keys, caches, etc.) +# storage_dir = "./data" + +# Request timeout in seconds +# timeout = 30 + +# --- Telemetry (requires `--telemetry` feature) --- +# [telemetry] + +# OpenTelemetry Protocol (OTLP) endpoint to export telemetry to +# endpoint = "https://otlp-gateway-prod-us-west-0.grafana.net/otlp" + +# Authentication token for the OTLP endpoint +# auth_token = "" + +# The domain you are running the payjoin-mailroom from. +# This serves as an identifier for metrics collection. +# operator_domain = "your-domain.example.com" + +# --- Access-control (requires `access-control` feature) --- +# [access_control] + +# Optional path to a MaxMind GeoIP2 / GeoLite2 Country database. +# If omitted but blocked_regions is non-empty, the free GeoLite2-Country +# database will be fetched automatically and cached under storage_dir. +# geo_db_path = "/path/to/GeoIP2-Country.mmdb" + +# ISO 3166-1 alpha-2 country codes whose requests should be blocked. +# blocked_regions = ["CU", "IR", "KP", "SY"] + +# IP addresses or CIDR ranges whose requests should be blocked. +# blocked_ips = ["192.0.2.0/24", "2001:db8::1"] + +# --- V1 protocol --- +# Uncomment the [v1] section to enable V1 fallback support. +# (address screening requires `access-control` feature) +# [v1] + +# Path to a local file containing blocked Bitcoin addresses (one per line). +# blocked_addresses_path = "/path/to/blocked_addresses.txt" + +# URL to periodically fetch an updated blocked-address list from. +# blocked_addresses_url = "https://example.com/blocked_addresses.txt" + +# How often (in seconds) to refresh the remote address list (default: 86400). +# blocked_addresses_refresh_secs = 86400 diff --git a/payjoin-mailroom/src/access_control.rs b/payjoin-mailroom/src/access_control.rs new file mode 100644 index 000000000..f8bc39594 --- /dev/null +++ b/payjoin-mailroom/src/access_control.rs @@ -0,0 +1,236 @@ +use std::collections::HashSet; +use std::net::IpAddr; +use std::path::Path; + +use maxminddb::PathElement; + +use crate::config::AccessControlConfig; + +pub struct IpFilter { + geo_reader: Option>>, + blocked_regions: HashSet, + blocked_ips: Vec, +} + +impl IpFilter { + pub async fn from_config( + config: &AccessControlConfig, + storage_dir: &Path, + ) -> anyhow::Result { + let geo_reader = match &config.geo_db_path { + Some(path) => Some(maxminddb::Reader::open_readfile(path)?), + None if !config.blocked_regions.is_empty() => { + let cached = storage_dir.join("access-control/geoip.mmdb"); + if cached.exists() { + match maxminddb::Reader::open_readfile(&cached) { + Ok(reader) => Some(reader), + Err(e) => { + tracing::warn!( + "Failed to open cached GeoIP database at {}: {e}; attempting refresh", + cached.display() + ); + fetch_geoip_db(&cached).await?; + Some(maxminddb::Reader::open_readfile(&cached)?) + } + } + } else { + fetch_geoip_db(&cached).await?; + Some(maxminddb::Reader::open_readfile(&cached)?) + } + } + None => None, + }; + + let blocked_regions = config.blocked_regions.iter().cloned().collect(); + + let blocked_ips = config + .blocked_ips + .iter() + .map(|s| { + s.parse::().or_else(|_| { + // Accept bare IP addresses without CIDR prefix length + Ok(ipnet::IpNet::from(s.parse::()?)) + }) + }) + .collect::, anyhow::Error>>()?; + + Ok(Self { geo_reader, blocked_regions, blocked_ips }) + } + + /// Returns `true` if the IP is allowed. Fail-open on GeoIP lookup errors. + pub fn check_ip(&self, ip: IpAddr) -> bool { + if self.blocked_ips.iter().any(|net| net.contains(&ip)) { + return false; + } + + self.check_geoip(ip) + } + + fn check_geoip(&self, ip: IpAddr) -> bool { + let reader = match &self.geo_reader { + Some(r) => r, + None => return true, + }; + + if self.blocked_regions.is_empty() { + return true; + } + + match reader.lookup(ip) { + Ok(result) => { + match result.decode_path::(&[ + PathElement::Key("country"), + PathElement::Key("iso_code"), + ]) { + Ok(Some(iso_code)) => !self.blocked_regions.contains(&iso_code), + _ => true, // no country info or decode error -> allow + } + } + Err(_) => true, // fail-open + } + } +} + +pub fn load_blocked_address_text(path: &Path) -> anyhow::Result { + Ok(std::fs::read_to_string(path)?) +} + +pub fn spawn_address_list_updater( + url: String, + refresh: std::time::Duration, + cache_path: std::path::PathBuf, + blocked: payjoin_directory::BlockedAddresses, +) { + tokio::spawn(async move { + loop { + match reqwest::get(&url).await.and_then(|r| r.error_for_status()) { + Ok(resp) => match resp.text().await { + Ok(body) => { + if let Err(e) = std::fs::write(&cache_path, &body) { + tracing::warn!("Failed to write address cache: {e}"); + } + let count = blocked.update_from_lines(&body).await; + tracing::info!("Updated blocked address list ({count} entries)"); + } + Err(e) => tracing::warn!("Failed to read address list response: {e}"), + }, + Err(e) => tracing::warn!("Failed to fetch address list: {e}"), + } + tokio::time::sleep(refresh).await; + } + }); +} + +async fn fetch_geoip_db(dest: &Path) -> anyhow::Result<()> { + use std::io::Read; + + let url = "https://cdn.jsdelivr.net/npm/geolite2-country/GeoLite2-Country.mmdb.gz"; + tracing::info!("Fetching GeoIP database from {}", url); + + let response = reqwest::get(url).await?; + if !response.status().is_success() { + anyhow::bail!("Failed to fetch GeoIP database: HTTP {}", response.status()); + } + let compressed = response.bytes().await?; + let mut decoder = flate2::read::GzDecoder::new(&compressed[..]); + let mut decompressed = Vec::new(); + decoder.read_to_end(&mut decompressed)?; + + if let Some(parent) = dest.parent() { + std::fs::create_dir_all(parent)?; + } + std::fs::write(dest, &decompressed)?; + tracing::info!("GeoIP database saved to {}", dest.display()); + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_geo_reader() -> maxminddb::Reader> { + maxminddb::Reader::open_readfile(concat!( + env!("CARGO_MANIFEST_DIR"), + "/test-data/GeoIP2-Country-Test.mmdb" + )) + .unwrap() + } + + #[test] + fn check_ip_allows_when_no_geo_reader() { + let ac = + IpFilter { geo_reader: None, blocked_regions: HashSet::new(), blocked_ips: vec![] }; + assert!(ac.check_ip("1.2.3.4".parse().unwrap())); + } + + #[test] + fn check_ip_allows_when_no_blocked_regions() { + let reader = test_geo_reader(); + let ac = IpFilter { + geo_reader: Some(reader), + blocked_regions: HashSet::new(), + blocked_ips: vec![], + }; + assert!(ac.check_ip("2.125.160.216".parse().unwrap())); + } + + #[test] + fn check_ip_blocks_blocked_region() { + let reader = test_geo_reader(); + // 2.125.160.216 is GB in the test database + let blocked_regions: HashSet = ["GB"].iter().map(|s| s.to_string()).collect(); + let ac = IpFilter { geo_reader: Some(reader), blocked_regions, blocked_ips: vec![] }; + assert!(!ac.check_ip("2.125.160.216".parse().unwrap())); + } + + #[test] + fn check_ip_allows_non_blocked_region() { + let reader = test_geo_reader(); + // 2.125.160.216 is GB in the test database + let blocked_regions: HashSet = ["US"].iter().map(|s| s.to_string()).collect(); + let ac = IpFilter { geo_reader: Some(reader), blocked_regions, blocked_ips: vec![] }; + assert!(ac.check_ip("2.125.160.216".parse().unwrap())); + } + + #[test] + fn check_ip_fail_open_on_unknown_ip() { + let reader = test_geo_reader(); + let blocked_regions: HashSet = ["US"].iter().map(|s| s.to_string()).collect(); + let ac = IpFilter { geo_reader: Some(reader), blocked_regions, blocked_ips: vec![] }; + // 127.0.0.1 won't be in test DB + assert!(ac.check_ip("127.0.0.1".parse().unwrap())); + } + + #[test] + fn blocked_ips_blocks_exact_ipv4() { + let blocked_ips = vec!["192.0.2.1/32".parse().unwrap()]; + let ac = IpFilter { geo_reader: None, blocked_regions: HashSet::new(), blocked_ips }; + assert!(!ac.check_ip("192.0.2.1".parse().unwrap())); + assert!(ac.check_ip("192.0.2.2".parse().unwrap())); + } + + #[test] + fn blocked_ips_blocks_exact_ipv6() { + let blocked_ips = vec!["2001:db8::1/128".parse().unwrap()]; + let ac = IpFilter { geo_reader: None, blocked_regions: HashSet::new(), blocked_ips }; + assert!(!ac.check_ip("2001:db8::1".parse().unwrap())); + assert!(ac.check_ip("2001:db8::2".parse().unwrap())); + } + + #[test] + fn blocked_ips_blocks_cidr_range() { + let blocked_ips = vec!["198.51.100.0/24".parse().unwrap()]; + let ac = IpFilter { geo_reader: None, blocked_regions: HashSet::new(), blocked_ips }; + assert!(!ac.check_ip("198.51.100.0".parse().unwrap())); + assert!(!ac.check_ip("198.51.100.255".parse().unwrap())); + assert!(ac.check_ip("198.51.101.0".parse().unwrap())); + } + + #[test] + fn empty_blocked_ips_allows_all() { + let ac = + IpFilter { geo_reader: None, blocked_regions: HashSet::new(), blocked_ips: vec![] }; + assert!(ac.check_ip("192.0.2.1".parse().unwrap())); + assert!(ac.check_ip("2001:db8::1".parse().unwrap())); + } +} diff --git a/payjoin-mailroom/src/config.rs b/payjoin-mailroom/src/config.rs index 813f22a04..a040445ae 100644 --- a/payjoin-mailroom/src/config.rs +++ b/payjoin-mailroom/src/config.rs @@ -12,11 +12,28 @@ pub struct Config { pub storage_dir: PathBuf, #[serde(deserialize_with = "deserialize_duration_secs")] pub timeout: Duration, - pub enable_v1: bool, + pub v1: Option, #[cfg(feature = "telemetry")] pub telemetry: Option, #[cfg(feature = "acme")] pub acme: Option, + #[cfg(feature = "access-control")] + pub access_control: Option, +} + +/// V1 protocol configuration. +/// +/// Present in [`Config`] to enable the V1 fallback path. +/// Contains optional address-screening settings that only apply to V1. +#[derive(Debug, Clone, Deserialize, Default)] +#[serde(default)] +pub struct V1Config { + #[cfg(feature = "access-control")] + pub blocked_addresses_path: Option, + #[cfg(feature = "access-control")] + pub blocked_addresses_url: Option, + #[cfg(feature = "access-control")] + pub blocked_addresses_refresh_secs: Option, } #[cfg(feature = "telemetry")] @@ -36,6 +53,15 @@ pub struct AcmeConfig { pub directory_url: Option, } +#[cfg(feature = "access-control")] +#[derive(Debug, Clone, Deserialize, Default)] +#[serde(default)] +pub struct AccessControlConfig { + pub geo_db_path: Option, + pub blocked_regions: Vec, + pub blocked_ips: Vec, +} + #[cfg(feature = "acme")] impl AcmeConfig { pub fn into_rustls_config( @@ -59,11 +85,13 @@ impl Default for Config { listener: "[::]:8080".parse().expect("valid default listener address"), storage_dir: PathBuf::from("./data"), timeout: Duration::from_secs(30), - enable_v1: false, + v1: None, #[cfg(feature = "telemetry")] telemetry: None, #[cfg(feature = "acme")] acme: None, + #[cfg(feature = "access-control")] + access_control: None, } } } @@ -81,17 +109,19 @@ impl Config { listener: ListenerAddress, storage_dir: PathBuf, timeout: Duration, - enable_v1: bool, + v1: Option, ) -> Self { Self { listener, storage_dir, timeout, - enable_v1, + v1, #[cfg(feature = "telemetry")] telemetry: None, #[cfg(feature = "acme")] acme: None, + #[cfg(feature = "access-control")] + access_control: None, } } @@ -109,6 +139,8 @@ impl Config { .list_separator(",") .with_list_parse_key("acme.domains") .with_list_parse_key("acme.contact") + .with_list_parse_key("access_control.blocked_regions") + .with_list_parse_key("access_control.blocked_ips") .try_parsing(true), ) .build()? diff --git a/payjoin-mailroom/src/lib.rs b/payjoin-mailroom/src/lib.rs index a95698209..08273bdba 100644 --- a/payjoin-mailroom/src/lib.rs +++ b/payjoin-mailroom/src/lib.rs @@ -1,6 +1,10 @@ +#[cfg(feature = "access-control")] +use axum::extract::connect_info::Connected; use axum::extract::State; use axum::http::Method; use axum::response::{IntoResponse, Response}; +#[cfg(feature = "access-control")] +use axum::serve::IncomingStream; use axum::Router; use config::Config; use ohttp_relay::SentinelTag; @@ -10,6 +14,8 @@ use tokio_listener::{Listener, SystemOptions, UserOptions}; use tower::{Service, ServiceBuilder}; use tracing::info; +#[cfg(feature = "access-control")] +pub mod access_control; pub mod cli; pub mod config; pub mod metrics; @@ -23,18 +29,29 @@ struct Services { directory: payjoin_directory::Service, relay: ohttp_relay::Service, metrics: MetricsService, + #[cfg(feature = "access-control")] + geoip: Option>, } pub async fn serve(config: Config, meter_provider: Option) -> anyhow::Result<()> { let sentinel_tag = generate_sentinel_tag(); + #[cfg(feature = "access-control")] + let geoip = init_geoip(&config).await?; + + let directory = init_directory(&config, sentinel_tag).await?; + let services = Services { - directory: init_directory(&config, sentinel_tag).await?, + directory, relay: ohttp_relay::Service::new(sentinel_tag).await, metrics: MetricsService::new(meter_provider), + #[cfg(feature = "access-control")] + geoip, }; let app = build_app(services); + #[cfg(feature = "access-control")] + let app = app.into_make_service_with_connect_info::(); let listener = Listener::bind(&config.listener, &SystemOptions::default(), &UserOptions::default()) @@ -62,10 +79,17 @@ pub async fn serve_manual_tls( let sentinel_tag = generate_sentinel_tag(); + #[cfg(feature = "access-control")] + let geoip = init_geoip(&config).await?; + + let directory = init_directory(&config, sentinel_tag).await?; + let services = Services { - directory: init_directory(&config, sentinel_tag).await?, + directory, relay: ohttp_relay::Service::new_with_roots(root_store, sentinel_tag).await, metrics: MetricsService::new(None), + #[cfg(feature = "access-control")] + geoip, }; let app = build_app(services); @@ -82,14 +106,18 @@ pub async fn serve_manual_tls( info!("Payjoin service listening on port {} with TLS", port); tokio::spawn(async move { axum_server::from_tcp_rustls(listener.into_std()?, tls)? - .serve(app.into_make_service()) + .serve(app.into_make_service_with_connect_info::()) .await .map_err(Into::into) }) } None => { info!("Payjoin service listening on port {} without TLS", port); - tokio::spawn(async move { axum::serve(listener, app).await.map_err(Into::into) }) + tokio::spawn(async move { + axum::serve(listener, app.into_make_service_with_connect_info::()) + .await + .map_err(Into::into) + }) } }; @@ -115,10 +143,17 @@ pub async fn serve_acme( let sentinel_tag = generate_sentinel_tag(); + #[cfg(feature = "access-control")] + let geoip = init_geoip(&config).await?; + + let directory = init_directory(&config, sentinel_tag).await?; + let services = Services { - directory: init_directory(&config, sentinel_tag).await?, + directory, relay: ohttp_relay::Service::new(sentinel_tag).await, metrics: MetricsService::new(meter_provider), + #[cfg(feature = "access-control")] + geoip, }; let app = build_app(services); @@ -148,7 +183,10 @@ pub async fn serve_acme( }); info!("Payjoin service listening on {} with ACME TLS", addr); - axum_server::bind(addr).acceptor(acceptor).serve(app.into_make_service()).await?; + axum_server::bind(addr) + .acceptor(acceptor) + .serve(app.into_make_service_with_connect_info::()) + .await?; Ok(()) } @@ -157,6 +195,17 @@ pub async fn serve_acme( /// at detecting self loops. fn generate_sentinel_tag() -> SentinelTag { SentinelTag::new(rand::thread_rng().gen()) } +#[cfg(feature = "access-control")] +impl Connected> for middleware::MaybePeerIp { + fn connect_info(stream: IncomingStream<'_, Listener>) -> Self { + let ip = match stream.remote_addr() { + tokio_listener::SomeSocketAddr::Tcp(addr) => Some(addr.ip()), + _ => None, + }; + Self(ip) + } +} + async fn init_directory( config: &Config, sentinel_tag: SentinelTag, @@ -167,7 +216,110 @@ async fn init_directory( let ohttp_keys_dir = config.storage_dir.join("ohttp-keys"); let ohttp_config = init_ohttp_config(&ohttp_keys_dir)?; - Ok(payjoin_directory::Service::new(db, ohttp_config.into(), sentinel_tag, config.enable_v1)) + let v1 = if config.v1.is_some() { + #[cfg(feature = "access-control")] + let blocked = init_blocked_addresses(config).await?; + #[cfg(not(feature = "access-control"))] + let blocked = None; + Some(payjoin_directory::V1::new(blocked)) + } else { + None + }; + Ok(payjoin_directory::Service::new(db, ohttp_config.into(), sentinel_tag, v1)) +} + +#[cfg(feature = "access-control")] +async fn init_geoip( + config: &Config, +) -> anyhow::Result>> { + match &config.access_control { + Some(ac_config) => { + let gi = access_control::IpFilter::from_config(ac_config, &config.storage_dir).await?; + info!("GeoIP access control enabled"); + Ok(Some(std::sync::Arc::new(gi))) + } + None => Ok(None), + } +} + +#[cfg(feature = "access-control")] +async fn init_blocked_addresses( + config: &Config, +) -> anyhow::Result> { + let v1_config = match &config.v1 { + Some(c) => c, + None => return Ok(None), + }; + + // Neither file nor URL configured + if v1_config.blocked_addresses_path.is_none() && v1_config.blocked_addresses_url.is_none() { + return Ok(None); + } + + // Load initial addresses from file if available + let blocked = match &v1_config.blocked_addresses_path { + Some(path) => { + let text = access_control::load_blocked_address_text(path)?; + let ba = payjoin_directory::BlockedAddresses::from_address_lines(&text); + info!("Loaded blocked addresses from {}", path.display()); + ba + } + None => payjoin_directory::BlockedAddresses::empty(), + }; + + // If URL configured, try initial fetch and spawn background updater + if let Some(url) = &v1_config.blocked_addresses_url { + let cache_path = config.storage_dir.join("blocked_addresses_cache.txt"); + let refresh = std::time::Duration::from_secs( + v1_config.blocked_addresses_refresh_secs.unwrap_or(86400), + ); + + // Try initial fetch; fall back to cache on failure + match reqwest::get(url).await.and_then(|r| r.error_for_status()) { + Ok(resp) => match resp.text().await { + Ok(body) => { + if let Err(e) = std::fs::write(&cache_path, &body) { + tracing::warn!("Failed to write address cache: {e}"); + } + let count = blocked.update_from_lines(&body).await; + info!("Fetched {count} blocked addresses from URL"); + } + Err(e) => { + tracing::warn!("Failed to read address list response: {e}"); + load_address_cache(&cache_path, &blocked).await; + } + }, + Err(e) => { + tracing::warn!("Failed to fetch address list: {e}"); + load_address_cache(&cache_path, &blocked).await; + } + } + + access_control::spawn_address_list_updater( + url.clone(), + refresh, + cache_path, + blocked.clone(), + ); + } + + Ok(Some(blocked)) +} + +#[cfg(feature = "access-control")] +async fn load_address_cache( + cache_path: &std::path::Path, + blocked: &payjoin_directory::BlockedAddresses, +) { + if cache_path.exists() { + match access_control::load_blocked_address_text(cache_path) { + Ok(text) => { + let count = blocked.update_from_lines(&text).await; + info!("Loaded {count} blocked addresses from cache"); + } + Err(e) => tracing::warn!("Failed to load address cache: {e}"), + } + } } fn init_ohttp_config( @@ -186,14 +338,28 @@ fn init_ohttp_config( fn build_app(services: Services) -> Router { let metrics = services.metrics.clone(); - Router::new() + + #[cfg(feature = "access-control")] + let geoip = services.geoip.clone(); + + #[allow(unused_mut)] + let mut router = Router::new() .fallback(route_request) .layer( ServiceBuilder::new() .layer(axum::middleware::from_fn_with_state(metrics.clone(), track_metrics)) .layer(axum::middleware::from_fn_with_state(metrics, track_connections)), ) - .with_state(services) + .with_state(services); + + #[cfg(feature = "access-control")] + { + router = router + .layer(axum::middleware::from_fn(middleware::check_geoip)) + .layer(axum::Extension(geoip)); + } + + router } async fn route_request( @@ -260,7 +426,7 @@ mod tests { "[::]:0".parse().expect("valid listener address"), tempdir.path().to_path_buf(), Duration::from_secs(2), - false, + None, ); let mut root_store = RootCertStore::empty(); @@ -355,7 +521,7 @@ mod tests { "[::]:0".parse().expect("valid listener address"), tempdir.path().to_path_buf(), Duration::from_secs(2), - false, + None, ); let sentinel_tag = generate_sentinel_tag(); @@ -363,6 +529,8 @@ mod tests { directory: init_directory(&config, sentinel_tag).await.unwrap(), relay: ohttp_relay::Service::new(sentinel_tag).await, metrics: MetricsService::new(Some(provider.clone())), + #[cfg(feature = "access-control")] + geoip: None, }; let app = build_app(services); diff --git a/payjoin-mailroom/src/middleware.rs b/payjoin-mailroom/src/middleware.rs index 875bc49b4..786337398 100644 --- a/payjoin-mailroom/src/middleware.rs +++ b/payjoin-mailroom/src/middleware.rs @@ -4,6 +4,35 @@ use axum::response::Response; use crate::metrics::MetricsService; +#[cfg(feature = "access-control")] +#[derive(Clone, Debug)] +pub struct MaybePeerIp(pub Option); + +#[cfg(feature = "access-control")] +pub async fn check_geoip(req: Request, next: Next) -> Response { + use axum::http::StatusCode; + + let geoip = req.extensions().get::>>(); + + if let Some(Some(geoip)) = geoip { + if let Some(connect_info) = + req.extensions().get::>() + { + if let Some(ip) = connect_info.0 .0 { + if !geoip.check_ip(ip) { + tracing::warn!("Blocked request from {ip} due to GeoIP policy"); + return Response::builder() + .status(StatusCode::FORBIDDEN) + .body(axum::body::Body::empty()) + .expect("valid response"); + } + } + } + } + + next.run(req).await +} + pub async fn track_metrics( metrics: axum::extract::State, req: Request, diff --git a/payjoin-mailroom/test-data/GeoIP2-Country-Test.mmdb b/payjoin-mailroom/test-data/GeoIP2-Country-Test.mmdb new file mode 100644 index 000000000..840f89384 Binary files /dev/null and b/payjoin-mailroom/test-data/GeoIP2-Country-Test.mmdb differ diff --git a/payjoin-test-utils/src/lib.rs b/payjoin-test-utils/src/lib.rs index e3b8fcece..9c467ca29 100644 --- a/payjoin-test-utils/src/lib.rs +++ b/payjoin-test-utils/src/lib.rs @@ -121,7 +121,7 @@ pub async fn init_directory( "[::]:0".parse().expect("valid listener address"), tempdir.path().to_path_buf(), Duration::from_secs(2), - true, + Some(payjoin_mailroom::config::V1Config::default()), ); let tls_config = RustlsConfig::from_der(vec![local_cert_key.0], local_cert_key.1).await?; @@ -149,7 +149,7 @@ async fn init_ohttp_relay( "[::]:0".parse().expect("valid listener address"), tempdir.path().to_path_buf(), Duration::from_secs(2), - false, + None, ); let (port, handle) = payjoin_mailroom::serve_manual_tls(config, None, root_store)