diff --git a/Cargo-minimal.lock b/Cargo-minimal.lock index a958e7350..48d3be19f 100644 --- a/Cargo-minimal.lock +++ b/Cargo-minimal.lock @@ -2090,6 +2090,12 @@ version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "187674a687eed5fe42285b40c6291f9a01517d415fad1c3cbc6a9f778af7fcd4" +[[package]] +name = "ipnetwork" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf370abdafd54d13e54a620e8c3e1145f28e46cc9d704bc6d94414559df41763" + [[package]] name = "iri-string" version = "0.7.8" @@ -2280,6 +2286,19 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" +[[package]] +name = "maxminddb" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5e371ee70dfbe063e098d1f90f01eee1458db7b0d7c03cd01e95453aa0e04e6" +dependencies = [ + "ipnetwork", + "log", + "memchr", + "serde", + "thiserror 2.0.17", +] + [[package]] name = "memchr" version = "2.7.4" @@ -2777,6 +2796,8 @@ dependencies = [ "axum-server", "clap", "config", + "flate2", + "maxminddb", "ohttp-relay", "opentelemetry", "opentelemetry-otlp", diff --git a/Cargo-recent.lock b/Cargo-recent.lock index a958e7350..48d3be19f 100644 --- a/Cargo-recent.lock +++ b/Cargo-recent.lock @@ -2090,6 +2090,12 @@ version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "187674a687eed5fe42285b40c6291f9a01517d415fad1c3cbc6a9f778af7fcd4" +[[package]] +name = "ipnetwork" +version = "0.21.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf370abdafd54d13e54a620e8c3e1145f28e46cc9d704bc6d94414559df41763" + [[package]] name = "iri-string" version = "0.7.8" @@ -2280,6 +2286,19 @@ version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47e1ffaa40ddd1f3ed91f717a33c8c0ee23fff369e3aa8772b9605cc1d22f4c3" +[[package]] +name = "maxminddb" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d5e371ee70dfbe063e098d1f90f01eee1458db7b0d7c03cd01e95453aa0e04e6" +dependencies = [ + "ipnetwork", + "log", + "memchr", + "serde", + "thiserror 2.0.17", +] + [[package]] name = "memchr" version = "2.7.4" @@ -2777,6 +2796,8 @@ dependencies = [ "axum-server", "clap", "config", + "flate2", + "maxminddb", "ohttp-relay", "opentelemetry", "opentelemetry-otlp", diff --git a/flake.nix b/flake.nix index 8f42dec71..2e5bcfc00 100644 --- a/flake.nix +++ b/flake.nix @@ -85,6 +85,7 @@ filter = path: type: (builtins.match ".*nginx.conf.template$" path != null) + || (builtins.match ".*/payjoin-mailroom/test-data/.*" path != null) || (craneLibVersions.msrv.filterCargoSources path type); name = "source"; }; diff --git a/payjoin-directory/src/db/files.rs b/payjoin-directory/src/db/files.rs index 42d9c34d3..f98032931 100644 --- a/payjoin-directory/src/db/files.rs +++ b/payjoin-directory/src/db/files.rs @@ -37,7 +37,9 @@ struct V2WaitMapEntry { #[derive(Debug)] struct V1WaitMapEntry { - payload: Arc>, + /// The V1 payload. `take()`n after the first read for data minimization — + /// plaintext PSBTs should not linger in memory longer than needed. + payload: Option>>, sender: oneshot::Sender>, } @@ -325,9 +327,12 @@ impl DbTrait for Db { impl Mailboxes { async fn read(&mut self, id: &ShortId) -> io::Result>>> { // V1 POST requests are only stored in memory since they are - // unencrypted. Check this hash table first. - if let Some(V1WaitMapEntry { payload, .. }) = self.pending_v1.get(id) { - return Ok(Some(payload.clone())); + // unencrypted. Check this hash table first. Use take() for data + // minimization — clear the plaintext payload after first read. + if let Some(entry) = self.pending_v1.get_mut(id) { + if let Some(payload) = entry.payload.take() { + return Ok(Some(payload)); + } } // V2 requests are stored on disk @@ -358,8 +363,11 @@ impl Mailboxes { return Err(Error::OverCapacity); } - if self.pending_v1.contains_key(id) { - return Err(Error::OverCapacity); + if let Some(entry) = self.pending_v1.get(id) { + if entry.payload.is_some() { + return Err(Error::OverCapacity); + } + return Err(Error::AlreadyRead); } let receiver = self @@ -419,13 +427,17 @@ impl Mailboxes { let payload = payload.clone(); let (sender, receiver) = oneshot::channel::>(); ret = Some(receiver); - V1WaitMapEntry { payload, sender } + V1WaitMapEntry { payload: Some(payload), sender } }); - // If there are pending readers, satisfy them and mark the payload as read + // If there are pending readers, satisfy them with the payload + // and clear the in-memory copy for data minimization if let Some(pending) = self.pending_v2.remove(id) { trace!("notifying pending readers for {} (v1 fallback)", id); - pending.sender.send(payload).expect("sending on oneshot channel must succeed"); + pending.sender.send(payload.clone()).expect("sending on oneshot channel must succeed"); + if let Some(entry) = self.pending_v1.get_mut(id) { + entry.payload.take(); + } } Ok(ret) @@ -568,6 +580,9 @@ pub enum Error { /// Operation rejected due to lack of capacity OverCapacity, + /// Indicates receiver already consumed the plaintext V1 request payload + AlreadyRead, + /// Indicates the sender that was waiting for the reply is no longer there V1SenderUnavailable, @@ -584,6 +599,7 @@ impl From for super::Error { match val { Error::V1SenderUnavailable => super::Error::V1SenderUnavailable, Error::OverCapacity => super::Error::OverCapacity, + Error::AlreadyRead => super::Error::AlreadyRead, Error::IO(e) => super::Error::Operational(e), } } @@ -603,6 +619,7 @@ impl std::fmt::Display for Error { use Error::*; match self { OverCapacity => "Database over capacity".fmt(f), + AlreadyRead => "Mailbox payload already read".fmt(f), V1SenderUnavailable => "Sender no longer connected".fmt(f), IO(e) => write!(f, "Internal Error: {e}"), } @@ -780,7 +797,7 @@ async fn test_v2_wait() -> std::io::Result<()> { match db.wait_for_v2_payload(&id).await { Err(super::Error::Timeout(_)) => {} - res => panic!("expected timeout, got {:?}", res), + res => panic!("expected timeout, got {res:?}"), } let read_task1 = tokio::spawn({ @@ -870,6 +887,59 @@ async fn test_v1_wait() -> std::io::Result<()> { Ok(()) } +#[tokio::test] +async fn test_v1_data_minimization() -> std::io::Result<()> { + let dir = tempfile::tempdir()?; + + let db = Arc::new( + Db::init(Duration::from_millis(500), dir.path().to_owned()) + .await + .expect("initializing mailbox database should succeed"), + ); + + let id = ShortId([0u8; 8]); + + // Spawn v1 sender in background + let v1_sender_task = tokio::spawn({ + let db = db.clone(); + async move { db.post_v1_request_and_wait_for_response(&id, b"request".to_vec()).await } + }); + + // Small delay to let v1 request post + tokio::time::sleep(Duration::from_millis(10)).await; + + // First read should return the payload + let res = db.wait_for_v2_payload(&id).await.expect("first read should succeed"); + assert_eq!(&res[..], b"request", "first read should return the payload"); + + // Subsequent reads should not return the plaintext payload again. + assert!( + matches!(db.wait_for_v2_payload(&id).await, Err(super::Error::AlreadyRead)), + "subsequent reads should indicate the payload was already consumed" + ); + + // Verify the payload was cleared from memory by checking directly + { + let guard = db.mailboxes.lock().await; + let entry = guard.pending_v1.get(&id); + assert!( + entry.is_none_or(|e| e.payload.is_none()), + "v1 payload should have been cleared after first read" + ); + } + + // V1 response flow should still work even after payload was cleared + db.post_v1_response(&id, b"response".to_vec()).await.expect("posting response should succeed"); + + let res = v1_sender_task + .await + .expect("joining task should succeed") + .expect("v1 sender should get response"); + assert_eq!(&res[..], b"response", "v1 sender should receive the response"); + + Ok(()) +} + // FIXME test is a bit slow and flakey, how to improve? // unfortunately tokio::time::pause() can't be used because this uses SystemTime // as the underlying clock type, due to timestamps originating from disk @@ -906,7 +976,7 @@ async fn test_prune() -> std::io::Result<()> { match read_task1.await.expect("joining should succeed") { Err(super::Error::Timeout(_)) => {} - res => panic!("expected timeout, got {:?}", res), + res => panic!("expected timeout, got {res:?}"), } db.prune().await.expect("pruning should not fail"); diff --git a/payjoin-directory/src/db/mod.rs b/payjoin-directory/src/db/mod.rs index b4970a532..4adfdb296 100644 --- a/payjoin-directory/src/db/mod.rs +++ b/payjoin-directory/src/db/mod.rs @@ -16,6 +16,7 @@ pub enum Error { Operational(OperationalError), Timeout(tokio::time::error::Elapsed), OverCapacity, + AlreadyRead, V1SenderUnavailable, } @@ -33,6 +34,7 @@ impl std::fmt::Display for Error { Operational(error) => write!(f, "Db error: {error}"), Timeout(timeout) => write!(f, "Timeout: {timeout}"), OverCapacity => "Database over capacity".fmt(f), + AlreadyRead => "Mailbox payload already read".fmt(f), V1SenderUnavailable => "Sender no longer connected".fmt(f), } } diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index 97b7eea43..28cc0fd56 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -68,6 +68,8 @@ pub struct Service { db: D, ohttp: ohttp::Server, sentinel_tag: SentinelTag, + blocked_addresses: Option>>>, + v1_disabled: bool, } impl tower::Service> for Service @@ -91,8 +93,14 @@ where } impl Service { - pub fn new(db: D, ohttp: ohttp::Server, sentinel_tag: SentinelTag) -> Self { - Self { db, ohttp, sentinel_tag } + pub fn new( + db: D, + ohttp: ohttp::Server, + sentinel_tag: SentinelTag, + blocked_addresses: Option>>>, + v1_disabled: bool, + ) -> Self { + Self { db, ohttp, sentinel_tag, blocked_addresses, v1_disabled } } #[cfg(feature = "_manual-tls")] @@ -375,6 +383,16 @@ impl Service { B::Error: Into, { trace!("Post fallback v1"); + + if self.v1_disabled { + return Ok(Response::builder() + .status(StatusCode::FORBIDDEN) + .header(CONTENT_TYPE, "application/json") + .body(full( + r#"{"errorCode": "v1-disabled", "message": "V1 is disabled on this server"}"#, + ))?); + } + let none_response = Response::builder() .status(StatusCode::SERVICE_UNAVAILABLE) .body(full(V1_UNAVAILABLE_RES_JSON))?; @@ -391,6 +409,24 @@ impl Service { Err(_) => return Ok(bad_request_body_res), }; + if let Some(blocked) = &self.blocked_addresses { + let blocked = blocked.read().await; + if !blocked.is_empty() { + match screen_v1_addresses(&body_str, &blocked) { + ScreenResult::Blocked => { + return Ok(Response::builder() + .status(StatusCode::FORBIDDEN) + .body(empty())?); + } + ScreenResult::Clean => {} + ScreenResult::ParseError(e) => { + warn!("Could not screen V1 payload: {e}"); + // fail-open: unparsable PSBTs can't complete transactions + } + } + } + } + let v2_compat_body = format!("{body_str}\n{query}"); let id = ShortId::from_str(id)?; handle_peek( @@ -438,6 +474,84 @@ impl Service { } } +enum ScreenResult { + Blocked, + Clean, + ParseError(String), +} + +fn screen_v1_addresses(body: &str, blocked: &std::collections::HashSet) -> ScreenResult { + use bitcoin::base64::prelude::{Engine, BASE64_STANDARD}; + use bitcoin::psbt::Psbt; + use bitcoin::{Address, Network}; + + // V1 payjoin deployments in use today are Bitcoin mainnet-only. + let network = Network::Bitcoin; + + let psbt_bytes = match BASE64_STANDARD.decode(body) { + Ok(b) => b, + Err(e) => return ScreenResult::ParseError(format!("base64 decode: {e}")), + }; + + let psbt = match Psbt::deserialize(&psbt_bytes) { + Ok(p) => p, + Err(e) => return ScreenResult::ParseError(format!("PSBT deserialize: {e}")), + }; + + // Check output addresses + for txout in &psbt.unsigned_tx.output { + if let Ok(addr) = Address::from_script(&txout.script_pubkey, network) { + if is_blocked_address(blocked, &addr) { + return ScreenResult::Blocked; + } + } + } + + // Check input addresses from witness_utxo and non_witness_utxo + for (i, input) in psbt.inputs.iter().enumerate() { + if let Some(ref utxo) = input.witness_utxo { + if let Ok(addr) = Address::from_script(&utxo.script_pubkey, network) { + if is_blocked_address(blocked, &addr) { + return ScreenResult::Blocked; + } + } + } + if let Some(ref tx) = input.non_witness_utxo { + if let Some(prev_out) = psbt.unsigned_tx.input.get(i) { + if let Some(txout) = tx.output.get(prev_out.previous_output.vout as usize) { + if let Ok(addr) = Address::from_script(&txout.script_pubkey, network) { + if is_blocked_address(blocked, &addr) { + return ScreenResult::Blocked; + } + } + } + } + } + } + + ScreenResult::Clean +} + +fn is_blocked_address( + blocked: &std::collections::HashSet, + addr: &bitcoin::Address, +) -> bool { + let rendered = addr.to_string(); + if blocked.contains(&rendered) { + return true; + } + if is_bech32_address(&rendered) { + return blocked.contains(&rendered.to_ascii_lowercase()) + || blocked.contains(&rendered.to_ascii_uppercase()); + } + false +} + +fn is_bech32_address(addr: &str) -> bool { + let lower = addr.to_ascii_lowercase(); + lower.starts_with("bc1") || lower.starts_with("tb1") || lower.starts_with("bcrt1") +} + fn handle_peek( result: Result>, db::Error>, timeout_response: Response>, @@ -453,6 +567,7 @@ fn handle_peek( db::Error::OverCapacity => Err(HandlerError::ServiceUnavailable(anyhow::Error::msg( "mailbox storage at capacity", ))), + db::Error::AlreadyRead => Ok(timeout_response), db::Error::V1SenderUnavailable => Err(HandlerError::SenderGone(anyhow::Error::msg( "Sender is unavailable try a new request", ))), @@ -599,3 +714,87 @@ fn empty() -> BoxBody { fn full>(chunk: T) -> BoxBody { Full::new(chunk.into()).map_err(|never| match never {}).boxed() } + +#[cfg(test)] +mod screen_tests { + use super::*; + + fn make_test_psbt_base64(output_address: &str) -> String { + use bitcoin::base64::prelude::{Engine, BASE64_STANDARD}; + use bitcoin::psbt::Psbt; + use bitcoin::{Address, Amount, Transaction, TxIn, TxOut}; + + let addr: Address = + output_address.parse().expect("valid address"); + let addr = addr.assume_checked(); + + let tx = Transaction { + version: bitcoin::transaction::Version::TWO, + lock_time: bitcoin::blockdata::locktime::absolute::LockTime::ZERO, + input: vec![TxIn::default()], + output: vec![TxOut { + value: Amount::from_sat(50_000), + script_pubkey: addr.script_pubkey(), + }], + }; + + let psbt = Psbt::from_unsigned_tx(tx).expect("valid psbt"); + let serialized = psbt.serialize(); + BASE64_STANDARD.encode(&serialized) + } + + #[test] + fn screen_blocks_blocked_output_address() { + let blocked_addr = "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa"; + let mut blocked = std::collections::HashSet::new(); + blocked.insert(blocked_addr.to_string()); + + let psbt_b64 = make_test_psbt_base64(blocked_addr); + assert!(matches!(screen_v1_addresses(&psbt_b64, &blocked), ScreenResult::Blocked)); + } + + #[test] + fn screen_allows_clean_psbt() { + let clean_addr = "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa"; + let blocked = std::collections::HashSet::new(); // empty + let psbt_b64 = make_test_psbt_base64(clean_addr); + assert!(matches!(screen_v1_addresses(&psbt_b64, &blocked), ScreenResult::Clean)); + } + + #[test] + fn screen_allows_non_blocked_address() { + let addr = "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa"; + let mut blocked = std::collections::HashSet::new(); + blocked.insert("3J98t1WpEZ73CNmQviecrnyiWrnqRhWNLy".to_string()); + + let psbt_b64 = make_test_psbt_base64(addr); + assert!(matches!(screen_v1_addresses(&psbt_b64, &blocked), ScreenResult::Clean)); + } + + #[test] + fn screen_parse_error_on_invalid_base64() { + let blocked = std::collections::HashSet::from(["test".to_string()]); + assert!(matches!( + screen_v1_addresses("not-valid-base64!!!", &blocked), + ScreenResult::ParseError(_) + )); + } + + #[test] + fn screen_parse_error_on_invalid_psbt() { + use bitcoin::base64::prelude::{Engine, BASE64_STANDARD}; + let blocked = std::collections::HashSet::from(["test".to_string()]); + let bad_psbt = BASE64_STANDARD.encode(b"not a psbt"); + assert!(matches!(screen_v1_addresses(&bad_psbt, &blocked), ScreenResult::ParseError(_))); + } + + #[test] + fn screen_blocks_uppercase_bech32_blocklist_entry() { + let addr = "bc1qxy2kgdygjrsqtzq2n0yrf2493p83kkfjhx0wlh"; + let mut blocked = std::collections::HashSet::new(); + blocked.insert(addr.to_ascii_uppercase()); + + let psbt_b64 = make_test_psbt_base64(addr); + assert!(matches!(screen_v1_addresses(&psbt_b64, &blocked), ScreenResult::Blocked)); + } +} diff --git a/payjoin-directory/src/main.rs b/payjoin-directory/src/main.rs index 9c68b7831..fe6fa03a5 100644 --- a/payjoin-directory/src/main.rs +++ b/payjoin-directory/src/main.rs @@ -29,7 +29,7 @@ async fn main() -> Result<(), BoxError> { .await .expect("Failed to initialize persistent storage"); - let service = Service::new(db, ohttp.into(), SentinelTag::new([0u8; 32])); + let service = Service::new(db, ohttp.into(), SentinelTag::new([0u8; 32]), None, false); let listener = TcpListener::bind(config.listen_addr).await?; diff --git a/payjoin-mailroom/Cargo.toml b/payjoin-mailroom/Cargo.toml index 5ff4c54e0..faae3dd1a 100644 --- a/payjoin-mailroom/Cargo.toml +++ b/payjoin-mailroom/Cargo.toml @@ -23,6 +23,7 @@ acme = [ "dep:tokio-stream", ] telemetry = ["dep:opentelemetry-otlp"] +access-control = ["dep:maxminddb", "dep:reqwest", "dep:flate2"] [dependencies] anyhow = "1.0" @@ -32,6 +33,8 @@ axum-server = { version = "0.8", features = [ ], optional = true } clap = { version = "4.5", features = ["derive", "env"] } config = "0.15" +flate2 = { version = "1", optional = true } +maxminddb = { version = "0.27", optional = true } ohttp-relay = { path = "../ohttp-relay", features = ["bootstrap"] } opentelemetry = "0.31" opentelemetry-otlp = { version = "0.31", optional = true, features = [ @@ -40,6 +43,9 @@ opentelemetry-otlp = { version = "0.31", optional = true, features = [ opentelemetry_sdk = "0.31" payjoin-directory = { path = "../payjoin-directory" } rand = "0.8" +reqwest = { version = "0.12", default-features = false, features = [ + "rustls-tls", +], optional = true } rustls = { version = "0.23", default-features = false, features = [ "ring", ], optional = true } diff --git a/payjoin-mailroom/README.md b/payjoin-mailroom/README.md index 11915626a..f86da2e5e 100644 --- a/payjoin-mailroom/README.md +++ b/payjoin-mailroom/README.md @@ -8,6 +8,8 @@ Note that this binary is under active development and thus the CLI and configura payjoin-mailroom reads configuration from `config.toml` (or the path given with `--config`). Every setting can also be supplied via environment variables prefixed with `PJ_`, using double underscores for nesting (e.g., `PJ_TELEMETRY__ENDPOINT`). +A complete example is available at [config.example.toml](config.example.toml), including optional `[access_control]`, `[telemetry]`, and `[acme]` sections. + ## Usage ### Cargo diff --git a/payjoin-mailroom/config.example.toml b/payjoin-mailroom/config.example.toml new file mode 100644 index 000000000..947f3f52e --- /dev/null +++ b/payjoin-mailroom/config.example.toml @@ -0,0 +1,47 @@ +## +## Payjoin Mailroom config.toml configuration file. +## Copy this to config.toml and adjust values for your deployment. +## + +# Core settings +listener = "[::]:8080" +storage_dir = "./data" +timeout = 30 + +# Optional telemetry settings (requires `--features telemetry`) +# +# [telemetry] +# endpoint = "https://otlp-gateway-prod-us-west-0.grafana.net/otlp" +# auth_token = "" +# operator_domain = "your-domain.example.com" + +# Optional ACME settings (requires `--features acme`) +# +# [acme] +# domains = ["mailroom.example.com"] +# contact = ["mailto:ops@example.com"] +# # Optional; defaults to Let's Encrypt production when omitted. +# directory_url = "https://acme-v02.api.letsencrypt.org/directory" + +# Optional access control settings (requires `--features access-control`) +# GeoIP and address screening are opt-in. V1 is deny-by-default and must be +# explicitly enabled. +# +# [access_control] +# # If omitted and blocked_regions is non-empty, DB-IP Lite MMDB is auto-fetched +# # into storage_dir on first startup. +# geo_db_path = "/absolute/path/to/geoip.mmdb" +# +# # ISO 3166-1 alpha-2 region codes to block. +# # Example policy set: CU (Cuba), IR (Iran), KP (North Korea), SY (Syria). +# blocked_regions = ["CU", "IR", "KP", "SY"] +# +# # Optional local file with one address per line. +# blocked_addresses_path = "/absolute/path/to/blocked_addresses.txt" +# +# # Optional URL for auto-updated address blocklist. +# blocked_addresses_url = "https://example.com/blocked_addresses.txt" +# blocked_addresses_refresh_secs = 86400 +# +# # Enable V1 requests. Defaults to false when omitted. +# enable_v1 = true diff --git a/payjoin-mailroom/src/access_control.rs b/payjoin-mailroom/src/access_control.rs new file mode 100644 index 000000000..c0454faba --- /dev/null +++ b/payjoin-mailroom/src/access_control.rs @@ -0,0 +1,252 @@ +use std::collections::HashSet; +use std::net::IpAddr; +use std::path::Path; + +use maxminddb::PathElement; + +use crate::config::AccessControlConfig; + +pub struct AccessControl { + geo_reader: Option>>, + blocked_regions: HashSet, +} + +impl AccessControl { + pub async fn from_config( + config: &AccessControlConfig, + storage_dir: &Path, + ) -> anyhow::Result { + let geo_reader = match &config.geo_db_path { + Some(path) => Some(maxminddb::Reader::open_readfile(path)?), + None if !config.blocked_regions.is_empty() => { + let cached = storage_dir.join("geoip.mmdb"); + if cached.exists() { + match maxminddb::Reader::open_readfile(&cached) { + Ok(reader) => Some(reader), + Err(e) => { + tracing::warn!( + "Failed to open cached GeoIP database at {}: {e}; attempting refresh", + cached.display() + ); + fetch_geoip_db(&cached).await?; + Some(maxminddb::Reader::open_readfile(&cached)?) + } + } + } else { + fetch_geoip_db(&cached).await?; + Some(maxminddb::Reader::open_readfile(&cached)?) + } + } + None => None, + }; + + let blocked_regions = config.blocked_regions.iter().cloned().collect(); + + Ok(Self { geo_reader, blocked_regions }) + } + + /// Returns `true` if the IP is allowed. Fail-open on lookup errors. + pub fn check_ip(&self, ip: IpAddr) -> bool { + let reader = match &self.geo_reader { + Some(r) => r, + None => return true, + }; + + if self.blocked_regions.is_empty() { + return true; + } + + match reader.lookup(ip) { + Ok(result) => { + match result.decode_path::(&[ + PathElement::Key("country"), + PathElement::Key("iso_code"), + ]) { + Ok(Some(iso_code)) => !self.blocked_regions.contains(&iso_code), + _ => true, // no country info or decode error -> allow + } + } + Err(_) => true, // fail-open + } + } +} + +pub fn load_blocked_addresses(path: &Path) -> anyhow::Result> { + let content = std::fs::read_to_string(path)?; + Ok(content + .lines() + .map(|l| normalize_blocked_address(l.trim())) + .filter(|l| !l.is_empty()) + .collect()) +} + +pub fn normalize_blocked_address(addr: &str) -> String { + if is_bech32_address(addr) { + addr.to_ascii_lowercase() + } else { + addr.to_string() + } +} + +fn is_bech32_address(addr: &str) -> bool { + let lower = addr.to_ascii_lowercase(); + lower.starts_with("bc1") || lower.starts_with("tb1") || lower.starts_with("bcrt1") +} + +pub fn spawn_address_list_updater( + url: String, + refresh: std::time::Duration, + cache_path: std::path::PathBuf, + blocked: std::sync::Arc>>, +) { + tokio::spawn(async move { + loop { + match fetch_address_list(&url).await { + Ok(addresses) => { + if let Err(e) = std::fs::write( + &cache_path, + addresses.iter().cloned().collect::>().join("\n"), + ) { + tracing::warn!("Failed to write address cache: {e}"); + } + let count = addresses.len(); + let mut guard = blocked.write().await; + *guard = addresses; + tracing::info!("Updated blocked address list ({count} entries)"); + } + Err(e) => tracing::warn!("Failed to fetch address list: {e}"), + } + tokio::time::sleep(refresh).await; + } + }); +} + +async fn fetch_address_list(url: &str) -> anyhow::Result> { + let body = reqwest::get(url).await?.error_for_status()?.text().await?; + Ok(body + .lines() + .map(|l| normalize_blocked_address(l.trim())) + .filter(|l| !l.is_empty()) + .collect()) +} + +async fn fetch_geoip_db(dest: &Path) -> anyhow::Result<()> { + use std::io::Read; + + let now = chrono_month_year(); + let url = + format!("https://download.db-ip.com/free/dbip-country-lite-{}-{}.mmdb.gz", now.0, now.1); + tracing::info!("Fetching GeoIP database from {}", url); + + let response = reqwest::get(&url).await?; + if !response.status().is_success() { + anyhow::bail!("Failed to fetch GeoIP database: HTTP {}", response.status()); + } + let compressed = response.bytes().await?; + let mut decoder = flate2::read::GzDecoder::new(&compressed[..]); + let mut decompressed = Vec::new(); + decoder.read_to_end(&mut decompressed)?; + + if let Some(parent) = dest.parent() { + std::fs::create_dir_all(parent)?; + } + std::fs::write(dest, &decompressed)?; + tracing::info!("GeoIP database saved to {}", dest.display()); + Ok(()) +} + +/// Returns (year, month) as strings for the DB-IP download URL. +fn chrono_month_year() -> (String, String) { + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .expect("system time should be after UNIX_EPOCH"); + let days_since_epoch = (now.as_secs() / 86_400) as i64; + let (year, month) = year_month_from_days_since_epoch(days_since_epoch); + (year.to_string(), format!("{month:02}")) +} + +fn year_month_from_days_since_epoch(days_since_epoch: i64) -> (i32, u32) { + // Exact conversion from Unix days to Gregorian year/month in UTC. + // Based on Howard Hinnant's civil calendar algorithm. + let z = days_since_epoch + 719_468; + let era = if z >= 0 { z } else { z - 146_096 } / 146_097; + let doe = z - era * 146_097; + let yoe = (doe - doe / 1_460 + doe / 36_524 - doe / 146_096) / 365; + let y = yoe + era * 400; + let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); + let mp = (5 * doy + 2) / 153; + let month = (mp + if mp < 10 { 3 } else { -9 }) as u32; + let year = (y + if month <= 2 { 1 } else { 0 }) as i32; + (year, month) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_geo_reader() -> maxminddb::Reader> { + // Keep a tiny MMDB fixture in-repo so GeoIP tests are offline and deterministic while + // still exercising real MaxMind DB parsing/lookup behavior. + maxminddb::Reader::open_readfile(concat!( + env!("CARGO_MANIFEST_DIR"), + "/test-data/GeoIP2-Country-Test.mmdb" + )) + .unwrap() + } + + #[test] + fn check_ip_allows_when_no_geo_reader() { + let ac = AccessControl { geo_reader: None, blocked_regions: HashSet::new() }; + assert!(ac.check_ip("1.2.3.4".parse().unwrap())); + } + + #[test] + fn check_ip_allows_when_no_blocked_regions() { + let reader = test_geo_reader(); + let ac = AccessControl { geo_reader: Some(reader), blocked_regions: HashSet::new() }; + assert!(ac.check_ip("2.125.160.216".parse().unwrap())); + } + + #[test] + fn check_ip_blocks_blocked_region() { + let reader = test_geo_reader(); + // 2.125.160.216 is GB in the test database + let blocked_regions: HashSet = ["GB"].iter().map(|s| s.to_string()).collect(); + let ac = AccessControl { geo_reader: Some(reader), blocked_regions }; + assert!(!ac.check_ip("2.125.160.216".parse().unwrap())); + } + + #[test] + fn check_ip_allows_non_blocked_region() { + let reader = test_geo_reader(); + // 2.125.160.216 is GB in the test database + let blocked_regions: HashSet = ["US"].iter().map(|s| s.to_string()).collect(); + let ac = AccessControl { geo_reader: Some(reader), blocked_regions }; + assert!(ac.check_ip("2.125.160.216".parse().unwrap())); + } + + #[test] + fn check_ip_fail_open_on_unknown_ip() { + let reader = test_geo_reader(); + let blocked_regions: HashSet = ["US"].iter().map(|s| s.to_string()).collect(); + let ac = AccessControl { geo_reader: Some(reader), blocked_regions }; + // 127.0.0.1 won't be in test DB + assert!(ac.check_ip("127.0.0.1".parse().unwrap())); + } + + #[test] + fn year_month_conversion_handles_leap_day() { + // 2024-02-29 00:00:00 UTC + let days = 19_782; + let (year, month) = year_month_from_days_since_epoch(days); + assert_eq!((year, month), (2024, 2)); + } + + #[test] + fn year_month_conversion_handles_year_start() { + // 2024-01-01 00:00:00 UTC + let days = 19_723; + let (year, month) = year_month_from_days_since_epoch(days); + assert_eq!((year, month), (2024, 1)); + } +} diff --git a/payjoin-mailroom/src/config.rs b/payjoin-mailroom/src/config.rs index ad8d949f1..9ad1fac3d 100644 --- a/payjoin-mailroom/src/config.rs +++ b/payjoin-mailroom/src/config.rs @@ -16,6 +16,8 @@ pub struct Config { pub telemetry: Option, #[cfg(feature = "acme")] pub acme: Option, + #[cfg(feature = "access-control")] + pub access_control: Option, } #[cfg(feature = "telemetry")] @@ -52,6 +54,17 @@ impl AcmeConfig { } } +#[cfg(feature = "access-control")] +#[derive(Debug, Clone, Deserialize, Default)] +#[serde(default)] +pub struct AccessControlConfig { + pub geo_db_path: Option, + pub blocked_regions: Vec, + pub blocked_addresses_path: Option, + pub blocked_addresses_url: Option, + pub blocked_addresses_refresh_secs: Option, + pub enable_v1: bool, +} impl Default for Config { fn default() -> Self { Self { @@ -62,6 +75,8 @@ impl Default for Config { telemetry: None, #[cfg(feature = "acme")] acme: None, + #[cfg(feature = "access-control")] + access_control: None, } } } @@ -84,6 +99,8 @@ impl Config { telemetry: None, #[cfg(feature = "acme")] acme: None, + #[cfg(feature = "access-control")] + access_control: None, } } @@ -101,6 +118,7 @@ impl Config { .list_separator(",") .with_list_parse_key("acme.domains") .with_list_parse_key("acme.contact") + .with_list_parse_key("access_control.blocked_regions") .try_parsing(true), ) .build()? diff --git a/payjoin-mailroom/src/lib.rs b/payjoin-mailroom/src/lib.rs index 52a157bcf..1bd94bce2 100644 --- a/payjoin-mailroom/src/lib.rs +++ b/payjoin-mailroom/src/lib.rs @@ -1,6 +1,10 @@ +#[cfg(feature = "access-control")] +use axum::extract::connect_info::Connected; use axum::extract::State; use axum::http::Method; use axum::response::{IntoResponse, Response}; +#[cfg(feature = "access-control")] +use axum::serve::IncomingStream; use axum::Router; use config::Config; use ohttp_relay::SentinelTag; @@ -10,6 +14,8 @@ use tokio_listener::{Listener, SystemOptions, UserOptions}; use tower::{Service, ServiceBuilder}; use tracing::info; +#[cfg(feature = "access-control")] +pub mod access_control; pub mod cli; pub mod config; pub mod metrics; @@ -23,18 +29,37 @@ struct Services { directory: payjoin_directory::Service, relay: ohttp_relay::Service, metrics: MetricsService, + #[cfg(feature = "access-control")] + access_control: Option>, } pub async fn serve(config: Config, meter_provider: Option) -> anyhow::Result<()> { let sentinel_tag = generate_sentinel_tag(); + #[cfg(feature = "access-control")] + let access_control = init_access_control(&config).await?; + #[cfg(feature = "access-control")] + let blocked_addresses = init_blocked_addresses(&config).await?; + #[cfg(not(feature = "access-control"))] + let blocked_addresses = None; + let services = Services { - directory: init_directory(&config, sentinel_tag).await?, + directory: init_directory( + &config, + sentinel_tag, + blocked_addresses, + get_v1_disabled(&config), + ) + .await?, relay: ohttp_relay::Service::new(sentinel_tag).await, metrics: MetricsService::new(meter_provider), + #[cfg(feature = "access-control")] + access_control, }; let app = build_app(services); + #[cfg(feature = "access-control")] + let app = app.into_make_service_with_connect_info::(); let listener = Listener::bind(&config.listener, &SystemOptions::default(), &UserOptions::default()) @@ -62,10 +87,23 @@ pub async fn serve_manual_tls( let sentinel_tag = generate_sentinel_tag(); + #[cfg(feature = "access-control")] + let blocked_addresses = init_blocked_addresses(&config).await?; + #[cfg(not(feature = "access-control"))] + let blocked_addresses = None; + let services = Services { - directory: init_directory(&config, sentinel_tag).await?, + directory: init_directory( + &config, + sentinel_tag, + blocked_addresses, + get_v1_disabled(&config), + ) + .await?, relay: ohttp_relay::Service::new_with_roots(root_store, sentinel_tag).await, metrics: MetricsService::new(None), + #[cfg(feature = "access-control")] + access_control: init_access_control(&config).await?, }; let app = build_app(services); @@ -82,14 +120,18 @@ pub async fn serve_manual_tls( info!("Payjoin service listening on port {} with TLS", port); tokio::spawn(async move { axum_server::from_tcp_rustls(listener.into_std()?, tls)? - .serve(app.into_make_service()) + .serve(app.into_make_service_with_connect_info::()) .await .map_err(Into::into) }) } None => { info!("Payjoin service listening on port {} without TLS", port); - tokio::spawn(async move { axum::serve(listener, app).await.map_err(Into::into) }) + tokio::spawn(async move { + axum::serve(listener, app.into_make_service_with_connect_info::()) + .await + .map_err(Into::into) + }) } }; @@ -115,10 +157,23 @@ pub async fn serve_acme( let sentinel_tag = generate_sentinel_tag(); + #[cfg(feature = "access-control")] + let blocked_addresses = init_blocked_addresses(&config).await?; + #[cfg(not(feature = "access-control"))] + let blocked_addresses = None; + let services = Services { - directory: init_directory(&config, sentinel_tag).await?, + directory: init_directory( + &config, + sentinel_tag, + blocked_addresses, + get_v1_disabled(&config), + ) + .await?, relay: ohttp_relay::Service::new(sentinel_tag).await, metrics: MetricsService::new(meter_provider), + #[cfg(feature = "access-control")] + access_control: init_access_control(&config).await?, }; let app = build_app(services); @@ -148,7 +203,10 @@ pub async fn serve_acme( }); info!("Payjoin service listening on {} with ACME TLS", addr); - axum_server::bind(addr).acceptor(acceptor).serve(app.into_make_service()).await?; + axum_server::bind(addr) + .acceptor(acceptor) + .serve(app.into_make_service_with_connect_info::()) + .await?; Ok(()) } @@ -157,9 +215,24 @@ pub async fn serve_acme( /// at detecting self loops. fn generate_sentinel_tag() -> SentinelTag { SentinelTag::new(rand::thread_rng().gen()) } +#[cfg(feature = "access-control")] +impl Connected> for middleware::MaybePeerIp { + fn connect_info(stream: IncomingStream<'_, Listener>) -> Self { + let ip = match stream.remote_addr() { + tokio_listener::SomeSocketAddr::Tcp(addr) => Some(addr.ip()), + _ => None, + }; + Self(ip) + } +} + async fn init_directory( config: &Config, sentinel_tag: SentinelTag, + blocked_addresses: Option< + std::sync::Arc>>, + >, + v1_disabled: bool, ) -> anyhow::Result> { let db = payjoin_directory::FilesDb::init(config.timeout, config.storage_dir.clone()).await?; db.spawn_background_prune().await; @@ -167,7 +240,131 @@ async fn init_directory( let ohttp_keys_dir = config.storage_dir.join("ohttp-keys"); let ohttp_config = init_ohttp_config(&ohttp_keys_dir)?; - Ok(payjoin_directory::Service::new(db, ohttp_config.into(), sentinel_tag)) + Ok(payjoin_directory::Service::new( + db, + ohttp_config.into(), + sentinel_tag, + blocked_addresses, + v1_disabled, + )) +} + +#[cfg(feature = "access-control")] +async fn init_access_control( + config: &Config, +) -> anyhow::Result>> { + match &config.access_control { + Some(ac_config) => { + let ac = + access_control::AccessControl::from_config(ac_config, &config.storage_dir).await?; + info!("Access control enabled"); + Ok(Some(std::sync::Arc::new(ac))) + } + None => Ok(None), + } +} + +#[cfg(feature = "access-control")] +async fn init_blocked_addresses( + config: &Config, +) -> anyhow::Result>>>> +{ + let ac_config = match &config.access_control { + Some(c) => c, + None => return Ok(None), + }; + + // Neither file nor URL configured + if ac_config.blocked_addresses_path.is_none() && ac_config.blocked_addresses_url.is_none() { + return Ok(None); + } + + // Load initial addresses from file if available + let mut addresses = std::collections::HashSet::new(); + if let Some(path) = &ac_config.blocked_addresses_path { + addresses = access_control::load_blocked_addresses(path)?; + info!("Loaded {} blocked addresses from {}", addresses.len(), path.display()); + } + + let blocked = std::sync::Arc::new(tokio::sync::RwLock::new(addresses)); + + // If URL configured, try initial fetch and spawn background updater + if let Some(url) = &ac_config.blocked_addresses_url { + let cache_path = config.storage_dir.join("blocked_addresses_cache.txt"); + let refresh = std::time::Duration::from_secs( + ac_config.blocked_addresses_refresh_secs.unwrap_or(86400), + ); + + // Try initial fetch; fall back to cache on failure + match reqwest::get(url).await { + Ok(resp) if resp.status().is_success() => match resp.text().await { + Ok(body) => { + let fetched: std::collections::HashSet = body + .lines() + .map(|l| access_control::normalize_blocked_address(l.trim())) + .filter(|l| !l.is_empty()) + .collect(); + if let Err(e) = std::fs::write( + &cache_path, + fetched.iter().cloned().collect::>().join("\n"), + ) { + tracing::warn!("Failed to write address cache: {e}"); + } + info!("Fetched {} blocked addresses from URL", fetched.len()); + *blocked.write().await = fetched; + } + Err(e) => { + tracing::warn!("Failed to read address list response: {e}"); + try_load_cache(&cache_path, &blocked).await; + } + }, + Ok(resp) => { + tracing::warn!("Failed to fetch address list: HTTP {}", resp.status()); + try_load_cache(&cache_path, &blocked).await; + } + Err(e) => { + tracing::warn!("Failed to fetch address list: {e}"); + try_load_cache(&cache_path, &blocked).await; + } + } + + access_control::spawn_address_list_updater( + url.clone(), + refresh, + cache_path, + blocked.clone(), + ); + } + + Ok(Some(blocked)) +} + +#[cfg(feature = "access-control")] +async fn try_load_cache( + cache_path: &std::path::Path, + blocked: &std::sync::Arc>>, +) { + if cache_path.exists() { + match access_control::load_blocked_addresses(cache_path) { + Ok(cached) => { + info!("Loaded {} blocked addresses from cache", cached.len()); + *blocked.write().await = cached; + } + Err(e) => tracing::warn!("Failed to load address cache: {e}"), + } + } +} + +fn get_v1_disabled(config: &Config) -> bool { + #[cfg(feature = "access-control")] + { + !config.access_control.as_ref().is_some_and(|ac| ac.enable_v1) + } + #[cfg(not(feature = "access-control"))] + { + let _ = config; + false + } } fn init_ohttp_config( @@ -186,14 +383,28 @@ fn init_ohttp_config( fn build_app(services: Services) -> Router { let metrics = services.metrics.clone(); - Router::new() + + #[cfg(feature = "access-control")] + let acaccess_control = services.access_control.clone(); + + #[allow(unused_mut)] + let mut router = Router::new() .fallback(route_request) .layer( ServiceBuilder::new() .layer(axum::middleware::from_fn_with_state(metrics.clone(), track_metrics)) .layer(axum::middleware::from_fn_with_state(metrics, track_connections)), ) - .with_state(services) + .with_state(services); + + #[cfg(feature = "access-control")] + { + router = router + .layer(axum::middleware::from_fn(middleware::check_access_control)) + .layer(axum::Extension(acaccess_control)); + } + + router } async fn route_request( @@ -284,7 +495,7 @@ mod tests { // Make a request through the relay that targets this same instance's directory. // The path format is /{gateway_url} where gateway_url points back to ourselves. - let ohttp_req_url = format!("{}/{}", base_url, base_url); + let ohttp_req_url = format!("{base_url}/{base_url}"); let response = client .post(&ohttp_req_url) @@ -322,7 +533,6 @@ mod tests { // Make a request through the relay instance to the directory instance. // Since they're different instances with different sentinel tags, this should work. let ohttp_req_url = format!("{}/{}", relay_url, directory_url); - let response = client .post(&ohttp_req_url) .header("Content-Type", "message/ohttp-req") @@ -358,9 +568,11 @@ mod tests { let sentinel_tag = generate_sentinel_tag(); let services = Services { - directory: init_directory(&config, sentinel_tag).await.unwrap(), + directory: init_directory(&config, sentinel_tag, None, false).await.unwrap(), relay: ohttp_relay::Service::new(sentinel_tag).await, metrics: MetricsService::new(Some(provider.clone())), + #[cfg(feature = "access-control")] + access_control: None, }; let app = build_app(services); @@ -382,4 +594,20 @@ mod tests { assert!(metric_names.contains(&TOTAL_CONNECTIONS), "missing total_connections"); assert!(metric_names.contains(&ACTIVE_CONNECTIONS), "missing active_connections"); } + + #[cfg(feature = "access-control")] + #[test] + fn v1_is_disabled_by_default() { + let config = Config::default(); + assert!(get_v1_disabled(&config)); + } + + #[cfg(feature = "access-control")] + #[test] + fn v1_can_be_enabled_explicitly() { + let mut config = Config::default(); + config.access_control = + Some(crate::config::AccessControlConfig { enable_v1: true, ..Default::default() }); + assert!(!get_v1_disabled(&config)); + } } diff --git a/payjoin-mailroom/src/middleware.rs b/payjoin-mailroom/src/middleware.rs index 875bc49b4..2d6943d9f 100644 --- a/payjoin-mailroom/src/middleware.rs +++ b/payjoin-mailroom/src/middleware.rs @@ -1,9 +1,18 @@ +#[cfg(feature = "access-control")] +use std::sync::Arc; + use axum::extract::Request; use axum::middleware::Next; use axum::response::Response; +#[cfg(feature = "access-control")] +use crate::access_control::AccessControl; use crate::metrics::MetricsService; +#[cfg(feature = "access-control")] +#[derive(Clone, Copy, Debug)] +pub struct MaybePeerIp(pub Option); + pub async fn track_metrics( metrics: axum::extract::State, req: Request, @@ -20,6 +29,35 @@ pub async fn track_metrics( response } +#[cfg(feature = "access-control")] +pub async fn check_access_control( + axum::extract::Extension(access_control): axum::extract::Extension>>, + req: Request, + next: Next, +) -> Response { + use axum::response::IntoResponse; + + if let Some(ac) = access_control.as_ref() { + let peer_ip = req + .extensions() + .get::>() + .map(|ci| ci.0.ip()) + .or_else(|| { + req.extensions().get::>().and_then(|ci| { + let maybe_peer_ip = ci.0; + maybe_peer_ip.0 + }) + }); + if let Some(ip) = peer_ip { + if !ac.check_ip(ip) { + tracing::warn!("Blocked request from {}", ip); + return (axum::http::StatusCode::FORBIDDEN, "").into_response(); + } + } + } + next.run(req).await +} + pub async fn track_connections( metrics: axum::extract::State, req: Request, diff --git a/payjoin-mailroom/test-data/GeoIP2-Country-Test.mmdb b/payjoin-mailroom/test-data/GeoIP2-Country-Test.mmdb new file mode 100644 index 000000000..840f89384 Binary files /dev/null and b/payjoin-mailroom/test-data/GeoIP2-Country-Test.mmdb differ diff --git a/payjoin-test-utils/Cargo.toml b/payjoin-test-utils/Cargo.toml index a153e4528..0b2fcff0d 100644 --- a/payjoin-test-utils/Cargo.toml +++ b/payjoin-test-utils/Cargo.toml @@ -8,6 +8,10 @@ repository = "https://github.com/payjoin/rust-payjoin" rust-version = "1.85" license = "MIT" +[features] +default = [] +access-control = ["payjoin-mailroom/access-control"] + [dependencies] axum-server = { version = "0.8", features = ["tls-rustls-no-provider"] } bitcoin = { version = "0.32.7", features = ["base64"] } diff --git a/payjoin-test-utils/src/lib.rs b/payjoin-test-utils/src/lib.rs index 1bccce351..4051d66c8 100644 --- a/payjoin-test-utils/src/lib.rs +++ b/payjoin-test-utils/src/lib.rs @@ -117,11 +117,23 @@ pub async fn init_directory( BoxSendSyncError, > { let tempdir = tempdir()?; - let config = payjoin_mailroom::config::Config::new( + let base_config = payjoin_mailroom::config::Config::new( "[::]:0".parse().expect("valid listener address"), tempdir.path().to_path_buf(), Duration::from_secs(2), ); + #[cfg(feature = "access-control")] + let config = { + let mut config = base_config; + // Test services exercise V1/V2 interoperability paths; keep V1 enabled explicitly. + config.access_control = Some(payjoin_mailroom::config::AccessControlConfig { + enable_v1: true, + ..Default::default() + }); + config + }; + #[cfg(not(feature = "access-control"))] + let config = base_config; let tls_config = RustlsConfig::from_der(vec![local_cert_key.0], local_cert_key.1).await?; @@ -144,11 +156,23 @@ async fn init_ohttp_relay( BoxSendSyncError, > { let tempdir = tempdir()?; - let config = payjoin_mailroom::config::Config::new( + let base_config = payjoin_mailroom::config::Config::new( "[::]:0".parse().expect("valid listener address"), tempdir.path().to_path_buf(), Duration::from_secs(2), ); + #[cfg(feature = "access-control")] + let config = { + let mut config = base_config; + // Keep relay-side directory fallback behavior aligned with integration expectations. + config.access_control = Some(payjoin_mailroom::config::AccessControlConfig { + enable_v1: true, + ..Default::default() + }); + config + }; + #[cfg(not(feature = "access-control"))] + let config = base_config; let (port, handle) = payjoin_mailroom::serve_manual_tls(config, None, root_store) .await diff --git a/payjoin/Cargo.toml b/payjoin/Cargo.toml index 61dc7bc58..7be371715 100644 --- a/payjoin/Cargo.toml +++ b/payjoin/Cargo.toml @@ -62,7 +62,7 @@ web-time = "1.1.0" [dev-dependencies] once_cell = "1.21.3" -payjoin-test-utils = { version = "0.0.1" } +payjoin-test-utils = { version = "0.0.1", features = ["access-control"] } tokio = { version = "1.47.1", features = ["full"] } tracing = "0.1.41"