From c173a25ebfc73318a91edea0402e2321bfa0e27c Mon Sep 17 00:00:00 2001 From: Yuval Kogman Date: Sat, 2 Aug 2025 23:44:53 +0200 Subject: [PATCH] refactor: introduce payjoin_directory::Service This is a preparatory change to avoid tight coupling to redis as a storage backend, and to enable ACME support by simplifying the listener logic. - implements hyper::Service trait - Arc> -> ohttp::Server - remove serve_payjoin_directory function - simplify and remove various listen helpers --- payjoin-directory/src/db.rs | 6 +- payjoin-directory/src/lib.rs | 622 ++++++++++++++++------------------ payjoin-directory/src/main.rs | 14 +- payjoin-test-utils/src/lib.rs | 26 +- 4 files changed, 324 insertions(+), 344 deletions(-) diff --git a/payjoin-directory/src/db.rs b/payjoin-directory/src/db.rs index b6f2c4bc1..6dfddcd62 100644 --- a/payjoin-directory/src/db.rs +++ b/payjoin-directory/src/db.rs @@ -9,14 +9,14 @@ const DEFAULT_COLUMN: &str = ""; const PJ_V1_COLUMN: &str = "pjv1"; #[derive(Debug, Clone)] -pub(crate) struct DbPool { +pub struct DbPool { client: Client, timeout: Duration, } /// Errors pertaining to [`DbPool`] #[derive(Debug)] -pub(crate) enum Error { +pub enum Error { Redis(RedisError), Timeout(tokio::time::error::Elapsed), } @@ -53,7 +53,6 @@ impl DbPool { Ok(Self { client, timeout }) } - /// Peek using [`DEFAULT_COLUMN`] as the channel type. pub async fn push_default(&self, mailbox_id: &ShortId, data: Vec) -> Result<()> { self.push(mailbox_id, DEFAULT_COLUMN, data).await } @@ -66,7 +65,6 @@ impl DbPool { self.push(mailbox_id, PJ_V1_COLUMN, data).await } - /// Peek using [`PJ_V1_COLUMN`] as the channel type. pub async fn peek_v1(&self, mailbox_id: &ShortId) -> Result> { self.peek_with_timeout(mailbox_id, PJ_V1_COLUMN).await } diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index bc8df2672..0dc3577c0 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -1,7 +1,5 @@ -use std::net::{IpAddr, Ipv6Addr, SocketAddr}; +use std::pin::Pin; use std::str::FromStr; -use std::sync::Arc; -use std::time::Duration; use anyhow::Result; use http_body_util::combinators::BoxBody; @@ -9,15 +7,12 @@ use http_body_util::{BodyExt, Empty, Full}; use hyper::body::{Body, Bytes, Incoming}; use hyper::header::{HeaderValue, ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_TYPE}; use hyper::server::conn::http1; -use hyper::service::service_fn; use hyper::{Method, Request, Response, StatusCode, Uri}; use hyper_util::rt::TokioIo; use payjoin::directory::{ShortId, ShortIdError, ENCAPSULATED_MESSAGE_BYTES}; -use tokio::net::TcpListener; -use tokio::sync::Mutex; use tracing::{debug, error, trace, warn}; -use crate::db::DbPool; +pub use crate::db::DbPool; pub mod key_config; pub use crate::key_config::*; @@ -37,42 +32,61 @@ const V1_UNAVAILABLE_RES_JSON: &str = r#"{{"errorCode": "unavailable", "message" mod db; -#[cfg(feature = "_danger-local-https")] -type BoxError = Box; +pub type BoxError = Box; #[cfg(feature = "_danger-local-https")] -pub async fn listen_tcp_with_tls_on_free_port( - db_host: String, - timeout: Duration, - cert_key: (Vec, Vec), - ohttp: ohttp::Server, -) -> Result<(u16, tokio::task::JoinHandle>), BoxError> { - let listener = tokio::net::TcpListener::bind("[::]:0").await?; - let port = listener.local_addr()?.port(); - println!("Directory server binding to port {}", listener.local_addr()?); - let handle = - listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key, ohttp).await?; - Ok((port, handle)) +fn init_tls_acceptor(cert_key: (Vec, Vec)) -> Result { + use rustls::pki_types::{CertificateDer, PrivateKeyDer}; + use rustls::ServerConfig; + use tokio_rustls::TlsAcceptor; + let (cert, key) = cert_key; + let cert = CertificateDer::from(cert); + let key = + PrivateKeyDer::try_from(key).map_err(|e| anyhow::anyhow!("Could not parse key: {}", e))?; + + let mut server_config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(vec![cert], key) + .map_err(|e| anyhow::anyhow!("TLS error: {}", e))?; + server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; + Ok(TlsAcceptor::from(std::sync::Arc::new(server_config))) } -// Helper function to avoid code duplication -#[cfg(feature = "_danger-local-https")] -async fn listen_tcp_with_tls_on_listener( - listener: tokio::net::TcpListener, - db_host: String, - timeout: Duration, - tls_config: (Vec, Vec), +#[derive(Clone)] +pub struct Service { + pool: DbPool, ohttp: ohttp::Server, -) -> Result>, BoxError> { - let pool = DbPool::new(timeout, db_host).await?; - let ohttp = Arc::new(Mutex::new(ohttp)); - let tls_acceptor = init_tls_acceptor(tls_config)?; - // Spawn the connection handling loop in a separate task - let handle = tokio::spawn(async move { +} + +impl hyper::service::Service> for Service { + type Response = Response>; + type Error = anyhow::Error; + type Future = + Pin> + Send>>; + + fn call(&self, req: Request) -> Self::Future { + let pool = self.pool.clone(); + let ohttp = self.ohttp.clone(); + let this = Service::new(pool, ohttp); + Box::pin(async move { this.serve_request(req).await }) + } +} + +impl Service { + pub fn new(pool: DbPool, ohttp: ohttp::Server) -> Self { Self { pool, ohttp } } + + #[cfg(feature = "_danger-local-https")] + pub async fn serve_tls( + self, + listener: tokio::net::TcpListener, + tls_config: (Vec, Vec), + ) -> Result<(), BoxError> { + let tls_acceptor = init_tls_acceptor(tls_config)?; + // Spawn the connection handling loop in a separate task + while let Ok((stream, _)) = listener.accept().await { - let pool = pool.clone(); - let ohttp = ohttp.clone(); let tls_acceptor = tls_acceptor.clone(); + let service = self.clone(); tokio::spawn(async move { let tls_stream = match tls_acceptor.accept(stream).await { Ok(tls_stream) => tls_stream, @@ -82,12 +96,7 @@ async fn listen_tcp_with_tls_on_listener( } }; if let Err(err) = http1::Builder::new() - .serve_connection( - TokioIo::new(tls_stream), - service_fn(move |req| { - serve_payjoin_directory(req, pool.clone(), ohttp.clone()) - }), - ) + .serve_connection(TokioIo::new(tls_stream), service) .with_upgrades() .await { @@ -96,171 +105,265 @@ async fn listen_tcp_with_tls_on_listener( }); } Ok(()) - }); - Ok(handle) -} + } -// Modify existing listen_tcp_with_tls to use the new helper -pub async fn listen_tcp( - port: u16, - db_host: String, - timeout: Duration, - ohttp: ohttp::Server, -) -> Result<(), Box> { - let pool = DbPool::new(timeout, db_host).await?; - let ohttp = Arc::new(Mutex::new(ohttp)); - let bind_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port); - let listener = TcpListener::bind(bind_addr).await?; - while let Ok((stream, _)) = listener.accept().await { - let pool = pool.clone(); - let ohttp = ohttp.clone(); - let io = TokioIo::new(stream); - tokio::spawn(async move { - if let Err(err) = http1::Builder::new() - .serve_connection( - io, - service_fn(move |req| { - serve_payjoin_directory(req, pool.clone(), ohttp.clone()) - }), - ) - .with_upgrades() - .await - { - error!("Error serving connection: {:?}", err); - } - }); + pub async fn serve_tcp(self, listener: tokio::net::TcpListener) -> Result<(), BoxError> { + while let Ok((stream, _)) = listener.accept().await { + let io = TokioIo::new(stream); + let service = self.clone(); + tokio::spawn(async move { + if let Err(err) = + http1::Builder::new().serve_connection(io, service).with_upgrades().await + { + error!("Error serving connection: {:?}", err); + } + }); + } + + Ok(()) } - Ok(()) -} + async fn serve_request( + &self, + req: Request, + ) -> Result>> { + let path = req.uri().path().to_string(); + let query = req.uri().query().unwrap_or_default().to_string(); + let (parts, body) = req.into_parts(); + + let path_segments: Vec<&str> = path.split('/').collect(); + debug!("Service::serve_request: {:?}", &path_segments); + let mut response = match (parts.method, path_segments.as_slice()) { + (Method::POST, ["", ".well-known", "ohttp-gateway"]) => + self.handle_ohttp_gateway(body).await, + (Method::GET, ["", ".well-known", "ohttp-gateway"]) => + self.handle_ohttp_gateway_get(&query).await, + (Method::POST, ["", ""]) => self.handle_ohttp_gateway(body).await, + (Method::GET, ["", "ohttp-keys"]) => self.get_ohttp_keys().await, + (Method::POST, ["", id]) => self.post_fallback_v1(id, query, body).await, + (Method::GET, ["", "health"]) => health_check().await, + (Method::GET, ["", ""]) => handle_directory_home_path().await, + _ => Ok(not_found()), + } + .unwrap_or_else(|e| e.to_response()); -#[cfg(feature = "_danger-local-https")] -pub async fn listen_tcp_with_tls( - port: u16, - db_host: String, - timeout: Duration, - cert_key: (Vec, Vec), - ohttp: ohttp::Server, -) -> Result>, BoxError> { - let addr = format!("0.0.0.0:{port}"); - let listener = tokio::net::TcpListener::bind(&addr).await?; - listen_tcp_with_tls_on_listener(listener, db_host, timeout, cert_key, ohttp).await -} + // Allow CORS for third-party access + response.headers_mut().insert(ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*")); -#[cfg(feature = "_danger-local-https")] -fn init_tls_acceptor(cert_key: (Vec, Vec)) -> Result { - use rustls::pki_types::{CertificateDer, PrivateKeyDer}; - use rustls::ServerConfig; - use tokio_rustls::TlsAcceptor; - let (cert, key) = cert_key; - let cert = CertificateDer::from(cert); - let key = - PrivateKeyDer::try_from(key).map_err(|e| anyhow::anyhow!("Could not parse key: {}", e))?; + Ok(response) + } - let mut server_config = ServerConfig::builder() - .with_no_client_auth() - .with_single_cert(vec![cert], key) - .map_err(|e| anyhow::anyhow!("TLS error: {}", e))?; - server_config.alpn_protocols = vec![b"h2".to_vec(), b"http/1.1".to_vec(), b"http/1.0".to_vec()]; - Ok(TlsAcceptor::from(Arc::new(server_config))) -} + async fn handle_ohttp_gateway( + &self, + body: Incoming, + ) -> Result>, HandlerError> { + // decapsulate + let ohttp_body = + body.collect().await.map_err(|e| HandlerError::BadRequest(e.into()))?.to_bytes(); + let (bhttp_req, res_ctx) = self + .ohttp + .decapsulate(&ohttp_body) + .map_err(|e| HandlerError::OhttpKeyRejection(e.into()))?; + let mut cursor = std::io::Cursor::new(bhttp_req); + let req = bhttp::Message::read_bhttp(&mut cursor) + .map_err(|e| HandlerError::BadRequest(e.into()))?; + let uri = Uri::builder() + .scheme(req.control().scheme().unwrap_or_default()) + .authority(req.control().authority().unwrap_or_default()) + .path_and_query(req.control().path().unwrap_or_default()) + .build()?; + let body = req.content().to_vec(); + let mut http_req = + Request::builder().uri(uri).method(req.control().method().unwrap_or_default()); + for header in req.header().fields() { + http_req = http_req.header(header.name(), header.value()) + } + let request = http_req.body(full(body))?; -async fn serve_payjoin_directory( - req: Request, - pool: DbPool, - ohttp: Arc>, -) -> Result>> { - let path = req.uri().path().to_string(); - let query = req.uri().query().unwrap_or_default().to_string(); - let (parts, body) = req.into_parts(); - - let path_segments: Vec<&str> = path.split('/').collect(); - debug!("serve_payjoin_directory: {:?}", &path_segments); - let mut response = match (parts.method, path_segments.as_slice()) { - (Method::POST, ["", ".well-known", "ohttp-gateway"]) => - handle_ohttp_gateway(body, pool, ohttp).await, - (Method::GET, ["", ".well-known", "ohttp-gateway"]) => - handle_ohttp_gateway_get(&ohttp, &query).await, - (Method::POST, ["", ""]) => handle_ohttp_gateway(body, pool, ohttp).await, - (Method::GET, ["", "ohttp-keys"]) => get_ohttp_keys(&ohttp).await, - (Method::POST, ["", id]) => post_fallback_v1(id, query, body, pool).await, - (Method::GET, ["", "health"]) => health_check().await, - (Method::GET, ["", ""]) => handle_directory_home_path().await, - _ => Ok(not_found()), + let response = self.handle_v2(request).await?; + + let (parts, body) = response.into_parts(); + let mut bhttp_res = bhttp::Message::response(parts.status.as_u16()); + for (name, value) in parts.headers.iter() { + bhttp_res.put_header(name.as_str(), value.to_str().unwrap_or_default()); + } + let full_body = body + .collect() + .await + .map_err(|e| HandlerError::InternalServerError(e.into()))? + .to_bytes(); + bhttp_res.write_content(&full_body); + let mut bhttp_bytes = Vec::new(); + bhttp_res + .write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes) + .map_err(|e| HandlerError::InternalServerError(e.into()))?; + bhttp_bytes.resize(BHTTP_REQ_BYTES, 0); + let ohttp_res = res_ctx + .encapsulate(&bhttp_bytes) + .map_err(|e| HandlerError::InternalServerError(e.into()))?; + assert!(ohttp_res.len() == ENCAPSULATED_MESSAGE_BYTES, "Unexpected OHTTP response size"); + Ok(Response::new(full(ohttp_res))) } - .unwrap_or_else(|e| e.to_response()); - // Allow CORS for third-party access - response.headers_mut().insert(ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*")); + async fn handle_v2( + &self, + req: Request>, + ) -> Result>, HandlerError> { + let path = req.uri().path().to_string(); + let (parts, body) = req.into_parts(); + + let path_segments: Vec<&str> = path.split('/').collect(); + debug!("handle_v2: {:?}", &path_segments); + match (parts.method, path_segments.as_slice()) { + (Method::POST, &["", id]) => self.post_mailbox(id, body).await, + (Method::GET, &["", id]) => self.get_mailbox(id).await, + (Method::PUT, &["", id]) => self.put_payjoin_v1(id, body).await, + _ => Ok(not_found()), + } + } - Ok(response) -} + async fn post_mailbox( + &self, + id: &str, + body: BoxBody, + ) -> Result>, HandlerError> { + let none_response = Response::builder().status(StatusCode::OK).body(empty())?; + trace!("post_mailbox"); + + let id = ShortId::from_str(id)?; + + let req = body + .collect() + .await + .map_err(|e| HandlerError::InternalServerError(e.into()))? + .to_bytes(); + if req.len() > V1_MAX_BUFFER_SIZE { + return Err(HandlerError::PayloadTooLarge); + } -async fn handle_ohttp_gateway( - body: Incoming, - pool: DbPool, - ohttp: Arc>, -) -> Result>, HandlerError> { - // decapsulate - let ohttp_body = - body.collect().await.map_err(|e| HandlerError::BadRequest(e.into()))?.to_bytes(); - let ohttp_locked = ohttp.lock().await; - let (bhttp_req, res_ctx) = ohttp_locked - .decapsulate(&ohttp_body) - .map_err(|e| HandlerError::OhttpKeyRejection(e.into()))?; - drop(ohttp_locked); - let mut cursor = std::io::Cursor::new(bhttp_req); - let req = - bhttp::Message::read_bhttp(&mut cursor).map_err(|e| HandlerError::BadRequest(e.into()))?; - let uri = Uri::builder() - .scheme(req.control().scheme().unwrap_or_default()) - .authority(req.control().authority().unwrap_or_default()) - .path_and_query(req.control().path().unwrap_or_default()) - .build()?; - let body = req.content().to_vec(); - let mut http_req = - Request::builder().uri(uri).method(req.control().method().unwrap_or_default()); - for header in req.header().fields() { - http_req = http_req.header(header.name(), header.value()) + match self.pool.push_default(&id, req.into()).await { + Ok(_) => Ok(none_response), + Err(e) => Err(HandlerError::InternalServerError(e.into())), + } } - let request = http_req.body(full(body))?; - let response = handle_v2(pool, request).await?; + async fn get_mailbox( + &self, + id: &str, + ) -> Result>, HandlerError> { + trace!("get_mailbox"); + let id = ShortId::from_str(id)?; + let timeout_response = Response::builder().status(StatusCode::ACCEPTED).body(empty())?; + handle_peek(self.pool.peek_default(&id).await, timeout_response) + } + async fn put_payjoin_v1( + &self, + id: &str, + body: BoxBody, + ) -> Result>, HandlerError> { + trace!("Put_payjoin_v1"); + let ok_response = Response::builder().status(StatusCode::OK).body(empty())?; + + let id = ShortId::from_str(id)?; + let req = body + .collect() + .await + .map_err(|e| HandlerError::InternalServerError(e.into()))? + .to_bytes(); + if req.len() > V1_MAX_BUFFER_SIZE { + return Err(HandlerError::PayloadTooLarge); + } - let (parts, body) = response.into_parts(); - let mut bhttp_res = bhttp::Message::response(parts.status.as_u16()); - for (name, value) in parts.headers.iter() { - bhttp_res.put_header(name.as_str(), value.to_str().unwrap_or_default()); + match self.pool.push_v1(&id, req.into()).await { + Ok(_) => Ok(ok_response), + Err(e) => Err(HandlerError::BadRequest(e.into())), + } + } + + async fn post_fallback_v1( + &self, + id: &str, + query: String, + body: impl Body, + ) -> Result>, HandlerError> { + trace!("Post fallback v1"); + let none_response = Response::builder() + .status(StatusCode::SERVICE_UNAVAILABLE) + .body(full(V1_UNAVAILABLE_RES_JSON))?; + let bad_request_body_res = + Response::builder().status(StatusCode::BAD_REQUEST).body(full(V1_REJECT_RES_JSON))?; + + let body_bytes = match body.collect().await { + Ok(bytes) => bytes.to_bytes(), + Err(_) => return Ok(bad_request_body_res), + }; + + let body_str = match String::from_utf8(body_bytes.to_vec()) { + Ok(body_str) => body_str, + Err(_) => return Ok(bad_request_body_res), + }; + + let v2_compat_body = format!("{body_str}\n{query}"); + let id = ShortId::from_str(id)?; + self.pool + .push_default(&id, v2_compat_body.into()) + .await + .map_err(|e| HandlerError::BadRequest(e.into()))?; + handle_peek(self.pool.peek_v1(&id).await, none_response) + } + + async fn handle_ohttp_gateway_get( + &self, + query: &str, + ) -> Result>, HandlerError> { + match query { + "allowed_purposes" => Ok(self.get_ohttp_allowed_purposes().await), + _ => self.get_ohttp_keys().await, + } + } + + async fn get_ohttp_keys(&self) -> Result>, HandlerError> { + let ohttp_keys = self + .ohttp + .config() + .encode() + .map_err(|e| HandlerError::InternalServerError(e.into()))?; + let mut res = Response::new(full(ohttp_keys)); + res.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/ohttp-keys")); + Ok(res) + } + + async fn get_ohttp_allowed_purposes(&self) -> Response> { + // Encode the magic string in the same format as a TLS ALPN protocol list (a + // U16BE length encoded list of U8 length encoded strings). + // + // The string is just "BIP77" followed by a UUID, that signals to relays + // that this OHTTP gateway will accept any requests associated with this + // purpose. + let mut res = Response::new(full(Bytes::from_static( + b"\x00\x01\x2aBIP77 454403bb-9f7b-4385-b31f-acd2dae20b7e", + ))); + + res.headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("application/x-ohttp-allowed-purposes")); + + res } - let full_body = - body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes(); - bhttp_res.write_content(&full_body); - let mut bhttp_bytes = Vec::new(); - bhttp_res - .write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes) - .map_err(|e| HandlerError::InternalServerError(e.into()))?; - bhttp_bytes.resize(BHTTP_REQ_BYTES, 0); - let ohttp_res = res_ctx - .encapsulate(&bhttp_bytes) - .map_err(|e| HandlerError::InternalServerError(e.into()))?; - assert!(ohttp_res.len() == ENCAPSULATED_MESSAGE_BYTES, "Unexpected OHTTP response size"); - Ok(Response::new(full(ohttp_res))) } -async fn handle_v2( - pool: DbPool, - req: Request>, +fn handle_peek( + result: db::Result>, + timeout_response: Response>, ) -> Result>, HandlerError> { - let path = req.uri().path().to_string(); - let (parts, body) = req.into_parts(); - - let path_segments: Vec<&str> = path.split('/').collect(); - debug!("handle_v2: {:?}", &path_segments); - match (parts.method, path_segments.as_slice()) { - (Method::POST, &["", id]) => post_mailbox(id, body, pool).await, - (Method::GET, &["", id]) => get_mailbox(id, pool).await, - (Method::PUT, &["", id]) => put_payjoin_v1(id, body, pool).await, - _ => Ok(not_found()), + match result { + Ok(buffered_req) => Ok(Response::new(full(buffered_req))), + Err(e) => match e { + db::Error::Redis(re) => { + error!("Redis error: {}", re); + Err(HandlerError::InternalServerError(anyhow::Error::msg("Internal server error"))) + } + db::Error::Timeout(_) => Ok(timeout_response), + }, } } @@ -375,153 +478,12 @@ impl From for HandlerError { } } -fn handle_peek( - result: db::Result>, - timeout_response: Response>, -) -> Result>, HandlerError> { - match result { - Ok(buffered_req) => Ok(Response::new(full(buffered_req))), - Err(e) => match e { - db::Error::Redis(re) => { - error!("Redis error: {}", re); - Err(HandlerError::InternalServerError(anyhow::Error::msg("Internal server error"))) - } - db::Error::Timeout(_) => Ok(timeout_response), - }, - } -} - -async fn post_fallback_v1( - id: &str, - query: String, - body: impl Body, - pool: DbPool, -) -> Result>, HandlerError> { - trace!("Post fallback v1"); - let none_response = Response::builder() - .status(StatusCode::SERVICE_UNAVAILABLE) - .body(full(V1_UNAVAILABLE_RES_JSON))?; - let bad_request_body_res = - Response::builder().status(StatusCode::BAD_REQUEST).body(full(V1_REJECT_RES_JSON))?; - - let body_bytes = match body.collect().await { - Ok(bytes) => bytes.to_bytes(), - Err(_) => return Ok(bad_request_body_res), - }; - - let body_str = match String::from_utf8(body_bytes.to_vec()) { - Ok(body_str) => body_str, - Err(_) => return Ok(bad_request_body_res), - }; - - let v2_compat_body = format!("{body_str}\n{query}"); - let id = ShortId::from_str(id)?; - pool.push_default(&id, v2_compat_body.into()) - .await - .map_err(|e| HandlerError::BadRequest(e.into()))?; - handle_peek(pool.peek_v1(&id).await, none_response) -} - -async fn put_payjoin_v1( - id: &str, - body: BoxBody, - pool: DbPool, -) -> Result>, HandlerError> { - trace!("Put_payjoin_v1"); - let ok_response = Response::builder().status(StatusCode::OK).body(empty())?; - - let id = ShortId::from_str(id)?; - let req = - body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes(); - if req.len() > V1_MAX_BUFFER_SIZE { - return Err(HandlerError::PayloadTooLarge); - } - - match pool.push_v1(&id, req.into()).await { - Ok(_) => Ok(ok_response), - Err(e) => Err(HandlerError::BadRequest(e.into())), - } -} - -async fn post_mailbox( - id: &str, - body: BoxBody, - pool: DbPool, -) -> Result>, HandlerError> { - let none_response = Response::builder().status(StatusCode::OK).body(empty())?; - trace!("post_mailbox"); - - let id = ShortId::from_str(id)?; - - let req = - body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes(); - if req.len() > V1_MAX_BUFFER_SIZE { - return Err(HandlerError::PayloadTooLarge); - } - - match pool.push_default(&id, req.into()).await { - Ok(_) => Ok(none_response), - Err(e) => Err(HandlerError::InternalServerError(e.into())), - } -} - -async fn get_mailbox( - id: &str, - pool: DbPool, -) -> Result>, HandlerError> { - trace!("get_mailbox"); - let id = ShortId::from_str(id)?; - let timeout_response = Response::builder().status(StatusCode::ACCEPTED).body(empty())?; - handle_peek(pool.peek_default(&id).await, timeout_response) -} - fn not_found() -> Response> { let mut res = Response::default(); *res.status_mut() = StatusCode::NOT_FOUND; res } -async fn handle_ohttp_gateway_get( - ohttp: &Arc>, - query: &str, -) -> Result>, HandlerError> { - match query { - "allowed_purposes" => Ok(get_ohttp_allowed_purposes().await), - _ => get_ohttp_keys(ohttp).await, - } -} - -async fn get_ohttp_keys( - ohttp: &Arc>, -) -> Result>, HandlerError> { - let ohttp_keys = ohttp - .lock() - .await - .config() - .encode() - .map_err(|e| HandlerError::InternalServerError(e.into()))?; - let mut res = Response::new(full(ohttp_keys)); - res.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("application/ohttp-keys")); - Ok(res) -} - -async fn get_ohttp_allowed_purposes() -> Response> { - // Encode the magic string in the same format as a TLS ALPN protocol list (a - // U16BE length encoded list of U8 length encoded strings). - // - // The string is just "BIP77" followed by a UUID, that signals to relays - // that this OHTTP gateway will accept any requests associated with this - // purpose. - let mut res = Response::new(full(Bytes::from_static( - b"\x00\x01\x2aBIP77 454403bb-9f7b-4385-b31f-acd2dae20b7e", - ))); - - res.headers_mut() - .insert(CONTENT_TYPE, HeaderValue::from_static("application/x-ohttp-allowed-purposes")); - - res -} - fn empty() -> BoxBody { Empty::::new().map_err(|never| match never {}).boxed() } diff --git a/payjoin-directory/src/main.rs b/payjoin-directory/src/main.rs index ca3f7fb61..f4f8d5948 100644 --- a/payjoin-directory/src/main.rs +++ b/payjoin-directory/src/main.rs @@ -1,13 +1,15 @@ use std::env; +use std::net::{IpAddr, Ipv6Addr, SocketAddr}; use payjoin_directory::*; +use tokio::net::TcpListener; use tracing_subscriber::filter::LevelFilter; use tracing_subscriber::EnvFilter; const DEFAULT_KEY_CONFIG_DIR: &str = "ohttp_keys"; #[tokio::main] -async fn main() -> Result<(), Box> { +async fn main() -> Result<(), BoxError> { init_logging(); let dir_port = @@ -36,7 +38,15 @@ async fn main() -> Result<(), Box> { } }; - payjoin_directory::listen_tcp(dir_port, db_host, timeout, ohttp.into()).await + let listener = bind_port(dir_port).await?; + let db = DbPool::new(timeout, db_host).await?; + let service = Service::new(db, ohttp.into()); + service.serve_tcp(listener).await +} + +async fn bind_port(port: u16) -> Result { + let bind_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port); + TcpListener::bind(bind_addr).await } fn init_logging() { diff --git a/payjoin-test-utils/src/lib.rs b/payjoin-test-utils/src/lib.rs index e084032b7..7adc9e895 100644 --- a/payjoin-test-utils/src/lib.rs +++ b/payjoin-test-utils/src/lib.rs @@ -1,4 +1,5 @@ use std::env; +use std::net::{IpAddr, Ipv6Addr, SocketAddr}; use std::result::Result; use std::str::FromStr; use std::sync::Arc; @@ -20,6 +21,7 @@ use rustls::pki_types::CertificateDer; use rustls::RootCertStore; use testcontainers::{clients, Container}; use testcontainers_modules::redis::{Redis, REDIS_PORT}; +use tokio::net::TcpListener; use tokio::task::JoinHandle; use tracing_subscriber::{EnvFilter, FmtSubscriber}; use url::Url; @@ -132,16 +134,24 @@ pub async fn init_directory( (u16, tokio::task::JoinHandle>), BoxSendSyncError, > { - println!("Database running on {db_host}"); let timeout = Duration::from_secs(2); let ohttp_server = payjoin_directory::gen_ohttp_server_config()?; - payjoin_directory::listen_tcp_with_tls_on_free_port( - db_host, - timeout, - local_cert_key, - ohttp_server.into(), - ) - .await + + println!("Database running on {db_host}"); + let db = payjoin_directory::DbPool::new(timeout, db_host).await?; + let service = payjoin_directory::Service::new(db, ohttp_server.into()); + + let listener = bind_free_port().await?; + let port = listener.local_addr()?.port(); + + let handle = tokio::spawn(service.serve_tls(listener, local_cert_key)); + + Ok((port, handle)) +} + +async fn bind_free_port() -> Result { + let bind_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 0); + TcpListener::bind(bind_addr).await } /// generate or get a DER encoded localhost cert and key.