Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions payjoin-directory/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ pub struct Config {
pub timeout: Duration,
pub storage_dir: PathBuf,
pub ohttp_keys: PathBuf, // TODO OhttpConfig struct with rotation params, etc
#[serde(default)]
pub enable_v1: bool,
#[cfg(feature = "acme")]
pub acme: Option<AcmeConfig>,
}
Expand Down Expand Up @@ -54,6 +56,7 @@ impl Config {
timeout: Duration::from_secs(built_config.get("timeout")?),
storage_dir: built_config.get("storage_dir")?,
ohttp_keys: built_config.get("ohttp_keys")?,
enable_v1: built_config.get("enable_v1").unwrap_or(false),
#[cfg(feature = "acme")]
acme: if built_config.get_table("acme").is_ok() {
Some(AcmeConfig {
Expand Down
117 changes: 113 additions & 4 deletions payjoin-directory/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ const V1_MAX_BUFFER_SIZE: usize = 65536;
const V1_REJECT_RES_JSON: &str =
r#"{{"errorCode": "original-psbt-rejected ", "message": "Body is not a string"}}"#;
const V1_UNAVAILABLE_RES_JSON: &str = r#"{{"errorCode": "unavailable", "message": "V2 receiver offline. V1 sends require synchronous communications."}}"#;
const V1_VERSION_UNSUPPORTED_RES_JSON: &str =
r#"{"errorCode": "version-unsupported", "supported": [2], "message": "V1 is not supported"}"#;

pub(crate) mod db;

Expand Down Expand Up @@ -68,6 +70,7 @@ pub struct Service<D: Db> {
db: D,
ohttp: ohttp::Server,
sentinel_tag: SentinelTag,
enable_v1: bool,
}

impl<D: Db, B> tower::Service<Request<B>> for Service<D>
Expand All @@ -91,8 +94,8 @@ where
}

impl<D: Db> Service<D> {
pub fn new(db: D, ohttp: ohttp::Server, sentinel_tag: SentinelTag) -> Self {
Self { db, ohttp, sentinel_tag }
pub fn new(db: D, ohttp: ohttp::Server, sentinel_tag: SentinelTag, enable_v1: bool) -> Self {
Self { db, ohttp, sentinel_tag, enable_v1 }
}

#[cfg(feature = "_manual-tls")]
Expand Down Expand Up @@ -214,7 +217,7 @@ impl<D: Db> Service<D> {
self.handle_ohttp_gateway_get(&query).await,
(Method::POST, ["", ""]) => self.handle_ohttp_gateway(body).await,
(Method::GET, ["", "ohttp-keys"]) => self.get_ohttp_keys().await,
(Method::POST, ["", id]) => self.post_fallback_v1(id, query, body).await,
(Method::POST, ["", id]) => self.handle_post_v1(id, query, body).await,
(Method::GET, ["", "health"]) => health_check().await,
(Method::GET, ["", ""]) => handle_directory_home_path().await,
_ => Ok(not_found()),
Expand All @@ -227,6 +230,28 @@ impl<D: Db> Service<D> {
Ok(response)
}

/// Route POST /{id}: forward to V1 fallback when enabled, otherwise reject.
async fn handle_post_v1<B>(
&self,
id: &str,
query: String,
body: B,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError>
where
B: Body<Data = Bytes> + Send + 'static,
B::Error: Into<BoxError>,
{
if self.enable_v1 {
self.post_fallback_v1(id, query, body).await
} else {
let _ = (id, query, body);
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
let _ = (id, query, body);

Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.header(CONTENT_TYPE, "application/json")
.body(full(V1_VERSION_UNSUPPORTED_RES_JSON))?)
}
}

/// Handle an encapsulated OHTTP request and return an encapsulated response
async fn handle_ohttp_gateway<B>(
&self,
Expand Down Expand Up @@ -304,7 +329,7 @@ impl<D: Db> Service<D> {
match (parts.method, path_segments.as_slice()) {
(Method::POST, &["", id]) => self.post_mailbox(id, body).await,
(Method::GET, &["", id]) => self.get_mailbox(id).await,
(Method::PUT, &["", id]) => self.put_payjoin_v1(id, body).await,
(Method::PUT, &["", id]) if self.enable_v1 => self.put_payjoin_v1(id, body).await,
_ => Ok(not_found()),
}
}
Expand Down Expand Up @@ -603,3 +628,87 @@ fn empty() -> BoxBody<Bytes, hyper::Error> {
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
Full::new(chunk.into()).map_err(|never| match never {}).boxed()
}

#[cfg(test)]
mod tests {
use std::time::Duration;

use http_body_util::BodyExt;
use hyper::body::Bytes;
use hyper::{Method, Request, StatusCode};
use ohttp_relay::SentinelTag;
use payjoin::directory::ShortId;

use super::*;

async fn test_service(enable_v1: bool) -> Service<FilesDb> {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this have any value to be pulled out into payjoin-test-utils to be used more generally?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in theory we could move these tests to the payjoin-mailroom which already has a similar utility, but for now I think it's inoffensive enough until payjoin-directory gets folded in

let dir = tempfile::tempdir().expect("tempdir");
let db = FilesDb::init(Duration::from_millis(100), dir.keep()).await.expect("db init");
let ohttp: ohttp::Server =
key_config::gen_ohttp_server_config().expect("ohttp config").into();
Service::new(db, ohttp, SentinelTag::new([0u8; 32]), enable_v1)
}

/// A valid ShortId encoded as bech32 for use in URL paths.
fn valid_short_id_path() -> String {
let id = ShortId([0u8; 8]);
id.to_string()
}

async fn collect_body(res: Response<BoxBody<Bytes, hyper::Error>>) -> (StatusCode, String) {
let (parts, body) = res.into_parts();
let bytes = body.collect().await.unwrap().to_bytes();
(parts.status, String::from_utf8(bytes.to_vec()).unwrap())
}

#[tokio::test]
async fn post_v1_when_disabled_returns_version_unsupported() {
let mut svc = test_service(false).await;
let id = valid_short_id_path();
let req = Request::builder()
.method(Method::POST)
.uri(format!("http://localhost/{id}"))
.body(Full::new(Bytes::from("base64-psbt")))
.unwrap();

let res = tower::Service::call(&mut svc, req).await.unwrap();
let (status, body) = collect_body(res).await;

assert_eq!(status, StatusCode::BAD_REQUEST);
assert_eq!(body, V1_VERSION_UNSUPPORTED_RES_JSON);
}

#[tokio::test]
async fn post_v1_with_invalid_body_returns_reject() {
let mut svc = test_service(true).await;
let id = valid_short_id_path();
let req = Request::builder()
.method(Method::POST)
.uri(format!("http://localhost/{id}"))
.body(Full::new(Bytes::from(vec![0xFF, 0xFE])))
.unwrap();

let res = tower::Service::call(&mut svc, req).await.unwrap();
let (status, body) = collect_body(res).await;

assert_eq!(status, StatusCode::BAD_REQUEST);
assert_eq!(body, V1_REJECT_RES_JSON);
}

#[tokio::test]
async fn post_v1_with_no_receiver_returns_unavailable() {
let mut svc = test_service(true).await;
let id = valid_short_id_path();
let req = Request::builder()
.method(Method::POST)
.uri(format!("http://localhost/{id}"))
.body(Full::new(Bytes::from("base64-psbt")))
.unwrap();

let res = tower::Service::call(&mut svc, req).await.unwrap();
let (status, body) = collect_body(res).await;

assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE);
assert_eq!(body, V1_UNAVAILABLE_RES_JSON);
}
}
2 changes: 1 addition & 1 deletion payjoin-directory/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ async fn main() -> Result<(), BoxError> {
.await
.expect("Failed to initialize persistent storage");

let service = Service::new(db, ohttp.into(), SentinelTag::new([0u8; 32]));
let service = Service::new(db, ohttp.into(), SentinelTag::new([0u8; 32]), config.enable_v1);

let listener = TcpListener::bind(config.listen_addr).await?;

Expand Down
10 changes: 9 additions & 1 deletion payjoin-mailroom/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub struct Config {
pub storage_dir: PathBuf,
#[serde(deserialize_with = "deserialize_duration_secs")]
pub timeout: Duration,
pub enable_v1: bool,
#[cfg(feature = "telemetry")]
pub telemetry: Option<TelemetryConfig>,
#[cfg(feature = "acme")]
Expand Down Expand Up @@ -58,6 +59,7 @@ impl Default for Config {
listener: "[::]:8080".parse().expect("valid default listener address"),
storage_dir: PathBuf::from("./data"),
timeout: Duration::from_secs(30),
enable_v1: false,
#[cfg(feature = "telemetry")]
telemetry: None,
#[cfg(feature = "acme")]
Expand All @@ -75,11 +77,17 @@ where
}

impl Config {
pub fn new(listener: ListenerAddress, storage_dir: PathBuf, timeout: Duration) -> Self {
pub fn new(
listener: ListenerAddress,
storage_dir: PathBuf,
timeout: Duration,
enable_v1: bool,
) -> Self {
Self {
listener,
storage_dir,
timeout,
enable_v1,
#[cfg(feature = "telemetry")]
telemetry: None,
#[cfg(feature = "acme")]
Expand Down
6 changes: 4 additions & 2 deletions payjoin-mailroom/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ async fn init_directory(
let ohttp_keys_dir = config.storage_dir.join("ohttp-keys");
let ohttp_config = init_ohttp_config(&ohttp_keys_dir)?;

Ok(payjoin_directory::Service::new(db, ohttp_config.into(), sentinel_tag))
Ok(payjoin_directory::Service::new(db, ohttp_config.into(), sentinel_tag, config.enable_v1))
}

fn init_ohttp_config(
Expand Down Expand Up @@ -260,6 +260,7 @@ mod tests {
"[::]:0".parse().expect("valid listener address"),
tempdir.path().to_path_buf(),
Duration::from_secs(2),
false,
);

let mut root_store = RootCertStore::empty();
Expand All @@ -284,7 +285,7 @@ mod tests {

// Make a request through the relay that targets this same instance's directory.
// The path format is /{gateway_url} where gateway_url points back to ourselves.
let ohttp_req_url = format!("{}/{}", base_url, base_url);
let ohttp_req_url = format!("{base_url}/{base_url}");

let response = client
.post(&ohttp_req_url)
Expand Down Expand Up @@ -354,6 +355,7 @@ mod tests {
"[::]:0".parse().expect("valid listener address"),
tempdir.path().to_path_buf(),
Duration::from_secs(2),
false,
);

let sentinel_tag = generate_sentinel_tag();
Expand Down
2 changes: 2 additions & 0 deletions payjoin-test-utils/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ pub async fn init_directory(
"[::]:0".parse().expect("valid listener address"),
tempdir.path().to_path_buf(),
Duration::from_secs(2),
true,
);

let tls_config = RustlsConfig::from_der(vec![local_cert_key.0], local_cert_key.1).await?;
Expand Down Expand Up @@ -148,6 +149,7 @@ async fn init_ohttp_relay(
"[::]:0".parse().expect("valid listener address"),
tempdir.path().to_path_buf(),
Duration::from_secs(2),
false,
);

let (port, handle) = payjoin_mailroom::serve_manual_tls(config, None, root_store)
Expand Down
Loading