diff --git a/payjoin-directory/src/config.rs b/payjoin-directory/src/config.rs index 42403def0..558a79b85 100644 --- a/payjoin-directory/src/config.rs +++ b/payjoin-directory/src/config.rs @@ -16,6 +16,8 @@ pub struct Config { pub timeout: Duration, pub storage_dir: PathBuf, pub ohttp_keys: PathBuf, // TODO OhttpConfig struct with rotation params, etc + #[serde(default)] + pub enable_v1: bool, #[cfg(feature = "acme")] pub acme: Option, } @@ -54,6 +56,7 @@ impl Config { timeout: Duration::from_secs(built_config.get("timeout")?), storage_dir: built_config.get("storage_dir")?, ohttp_keys: built_config.get("ohttp_keys")?, + enable_v1: built_config.get("enable_v1").unwrap_or(false), #[cfg(feature = "acme")] acme: if built_config.get_table("acme").is_ok() { Some(AcmeConfig { diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index d60f102d6..d295e7e5f 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -37,6 +37,8 @@ const V1_MAX_BUFFER_SIZE: usize = 65536; const V1_REJECT_RES_JSON: &str = r#"{{"errorCode": "original-psbt-rejected ", "message": "Body is not a string"}}"#; const V1_UNAVAILABLE_RES_JSON: &str = r#"{{"errorCode": "unavailable", "message": "V2 receiver offline. V1 sends require synchronous communications."}}"#; +const V1_VERSION_UNSUPPORTED_RES_JSON: &str = + r#"{"errorCode": "version-unsupported", "supported": [2], "message": "V1 is not supported"}"#; pub(crate) mod db; @@ -68,6 +70,7 @@ pub struct Service { db: D, ohttp: ohttp::Server, sentinel_tag: SentinelTag, + enable_v1: bool, } impl tower::Service> for Service @@ -91,8 +94,8 @@ where } impl Service { - pub fn new(db: D, ohttp: ohttp::Server, sentinel_tag: SentinelTag) -> Self { - Self { db, ohttp, sentinel_tag } + pub fn new(db: D, ohttp: ohttp::Server, sentinel_tag: SentinelTag, enable_v1: bool) -> Self { + Self { db, ohttp, sentinel_tag, enable_v1 } } #[cfg(feature = "_manual-tls")] @@ -214,7 +217,7 @@ impl Service { self.handle_ohttp_gateway_get(&query).await, (Method::POST, ["", ""]) => self.handle_ohttp_gateway(body).await, (Method::GET, ["", "ohttp-keys"]) => self.get_ohttp_keys().await, - (Method::POST, ["", id]) => self.post_fallback_v1(id, query, body).await, + (Method::POST, ["", id]) => self.handle_post_v1(id, query, body).await, (Method::GET, ["", "health"]) => health_check().await, (Method::GET, ["", ""]) => handle_directory_home_path().await, _ => Ok(not_found()), @@ -227,6 +230,28 @@ impl Service { Ok(response) } + /// Route POST /{id}: forward to V1 fallback when enabled, otherwise reject. + async fn handle_post_v1( + &self, + id: &str, + query: String, + body: B, + ) -> Result>, HandlerError> + where + B: Body + Send + 'static, + B::Error: Into, + { + if self.enable_v1 { + self.post_fallback_v1(id, query, body).await + } else { + let _ = (id, query, body); + Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .header(CONTENT_TYPE, "application/json") + .body(full(V1_VERSION_UNSUPPORTED_RES_JSON))?) + } + } + /// Handle an encapsulated OHTTP request and return an encapsulated response async fn handle_ohttp_gateway( &self, @@ -304,7 +329,7 @@ impl Service { match (parts.method, path_segments.as_slice()) { (Method::POST, &["", id]) => self.post_mailbox(id, body).await, (Method::GET, &["", id]) => self.get_mailbox(id).await, - (Method::PUT, &["", id]) => self.put_payjoin_v1(id, body).await, + (Method::PUT, &["", id]) if self.enable_v1 => self.put_payjoin_v1(id, body).await, _ => Ok(not_found()), } } @@ -603,3 +628,87 @@ fn empty() -> BoxBody { fn full>(chunk: T) -> BoxBody { Full::new(chunk.into()).map_err(|never| match never {}).boxed() } + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use http_body_util::BodyExt; + use hyper::body::Bytes; + use hyper::{Method, Request, StatusCode}; + use ohttp_relay::SentinelTag; + use payjoin::directory::ShortId; + + use super::*; + + async fn test_service(enable_v1: bool) -> Service { + let dir = tempfile::tempdir().expect("tempdir"); + let db = FilesDb::init(Duration::from_millis(100), dir.keep()).await.expect("db init"); + let ohttp: ohttp::Server = + key_config::gen_ohttp_server_config().expect("ohttp config").into(); + Service::new(db, ohttp, SentinelTag::new([0u8; 32]), enable_v1) + } + + /// A valid ShortId encoded as bech32 for use in URL paths. + fn valid_short_id_path() -> String { + let id = ShortId([0u8; 8]); + id.to_string() + } + + async fn collect_body(res: Response>) -> (StatusCode, String) { + let (parts, body) = res.into_parts(); + let bytes = body.collect().await.unwrap().to_bytes(); + (parts.status, String::from_utf8(bytes.to_vec()).unwrap()) + } + + #[tokio::test] + async fn post_v1_when_disabled_returns_version_unsupported() { + let mut svc = test_service(false).await; + let id = valid_short_id_path(); + let req = Request::builder() + .method(Method::POST) + .uri(format!("http://localhost/{id}")) + .body(Full::new(Bytes::from("base64-psbt"))) + .unwrap(); + + let res = tower::Service::call(&mut svc, req).await.unwrap(); + let (status, body) = collect_body(res).await; + + assert_eq!(status, StatusCode::BAD_REQUEST); + assert_eq!(body, V1_VERSION_UNSUPPORTED_RES_JSON); + } + + #[tokio::test] + async fn post_v1_with_invalid_body_returns_reject() { + let mut svc = test_service(true).await; + let id = valid_short_id_path(); + let req = Request::builder() + .method(Method::POST) + .uri(format!("http://localhost/{id}")) + .body(Full::new(Bytes::from(vec![0xFF, 0xFE]))) + .unwrap(); + + let res = tower::Service::call(&mut svc, req).await.unwrap(); + let (status, body) = collect_body(res).await; + + assert_eq!(status, StatusCode::BAD_REQUEST); + assert_eq!(body, V1_REJECT_RES_JSON); + } + + #[tokio::test] + async fn post_v1_with_no_receiver_returns_unavailable() { + let mut svc = test_service(true).await; + let id = valid_short_id_path(); + let req = Request::builder() + .method(Method::POST) + .uri(format!("http://localhost/{id}")) + .body(Full::new(Bytes::from("base64-psbt"))) + .unwrap(); + + let res = tower::Service::call(&mut svc, req).await.unwrap(); + let (status, body) = collect_body(res).await; + + assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE); + assert_eq!(body, V1_UNAVAILABLE_RES_JSON); + } +} diff --git a/payjoin-directory/src/main.rs b/payjoin-directory/src/main.rs index 9c68b7831..e60483f3d 100644 --- a/payjoin-directory/src/main.rs +++ b/payjoin-directory/src/main.rs @@ -29,7 +29,7 @@ async fn main() -> Result<(), BoxError> { .await .expect("Failed to initialize persistent storage"); - let service = Service::new(db, ohttp.into(), SentinelTag::new([0u8; 32])); + let service = Service::new(db, ohttp.into(), SentinelTag::new([0u8; 32]), config.enable_v1); let listener = TcpListener::bind(config.listen_addr).await?; diff --git a/payjoin-mailroom/src/config.rs b/payjoin-mailroom/src/config.rs index ad8d949f1..813f22a04 100644 --- a/payjoin-mailroom/src/config.rs +++ b/payjoin-mailroom/src/config.rs @@ -12,6 +12,7 @@ pub struct Config { pub storage_dir: PathBuf, #[serde(deserialize_with = "deserialize_duration_secs")] pub timeout: Duration, + pub enable_v1: bool, #[cfg(feature = "telemetry")] pub telemetry: Option, #[cfg(feature = "acme")] @@ -58,6 +59,7 @@ impl Default for Config { listener: "[::]:8080".parse().expect("valid default listener address"), storage_dir: PathBuf::from("./data"), timeout: Duration::from_secs(30), + enable_v1: false, #[cfg(feature = "telemetry")] telemetry: None, #[cfg(feature = "acme")] @@ -75,11 +77,17 @@ where } impl Config { - pub fn new(listener: ListenerAddress, storage_dir: PathBuf, timeout: Duration) -> Self { + pub fn new( + listener: ListenerAddress, + storage_dir: PathBuf, + timeout: Duration, + enable_v1: bool, + ) -> Self { Self { listener, storage_dir, timeout, + enable_v1, #[cfg(feature = "telemetry")] telemetry: None, #[cfg(feature = "acme")] diff --git a/payjoin-mailroom/src/lib.rs b/payjoin-mailroom/src/lib.rs index 52a157bcf..a95698209 100644 --- a/payjoin-mailroom/src/lib.rs +++ b/payjoin-mailroom/src/lib.rs @@ -167,7 +167,7 @@ async fn init_directory( let ohttp_keys_dir = config.storage_dir.join("ohttp-keys"); let ohttp_config = init_ohttp_config(&ohttp_keys_dir)?; - Ok(payjoin_directory::Service::new(db, ohttp_config.into(), sentinel_tag)) + Ok(payjoin_directory::Service::new(db, ohttp_config.into(), sentinel_tag, config.enable_v1)) } fn init_ohttp_config( @@ -260,6 +260,7 @@ mod tests { "[::]:0".parse().expect("valid listener address"), tempdir.path().to_path_buf(), Duration::from_secs(2), + false, ); let mut root_store = RootCertStore::empty(); @@ -284,7 +285,7 @@ mod tests { // Make a request through the relay that targets this same instance's directory. // The path format is /{gateway_url} where gateway_url points back to ourselves. - let ohttp_req_url = format!("{}/{}", base_url, base_url); + let ohttp_req_url = format!("{base_url}/{base_url}"); let response = client .post(&ohttp_req_url) @@ -354,6 +355,7 @@ mod tests { "[::]:0".parse().expect("valid listener address"), tempdir.path().to_path_buf(), Duration::from_secs(2), + false, ); let sentinel_tag = generate_sentinel_tag(); diff --git a/payjoin-test-utils/src/lib.rs b/payjoin-test-utils/src/lib.rs index 1bccce351..e3b8fcece 100644 --- a/payjoin-test-utils/src/lib.rs +++ b/payjoin-test-utils/src/lib.rs @@ -121,6 +121,7 @@ pub async fn init_directory( "[::]:0".parse().expect("valid listener address"), tempdir.path().to_path_buf(), Duration::from_secs(2), + true, ); let tls_config = RustlsConfig::from_der(vec![local_cert_key.0], local_cert_key.1).await?; @@ -148,6 +149,7 @@ async fn init_ohttp_relay( "[::]:0".parse().expect("valid listener address"), tempdir.path().to_path_buf(), Duration::from_secs(2), + false, ); let (port, handle) = payjoin_mailroom::serve_manual_tls(config, None, root_store)