-
Notifications
You must be signed in to change notification settings - Fork 92
Gate V1 protocol behind runtime feature flag #1336
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,6 +37,8 @@ const V1_MAX_BUFFER_SIZE: usize = 65536; | |
| const V1_REJECT_RES_JSON: &str = | ||
| r#"{{"errorCode": "original-psbt-rejected ", "message": "Body is not a string"}}"#; | ||
| const V1_UNAVAILABLE_RES_JSON: &str = r#"{{"errorCode": "unavailable", "message": "V2 receiver offline. V1 sends require synchronous communications."}}"#; | ||
| const V1_VERSION_UNSUPPORTED_RES_JSON: &str = | ||
| r#"{"errorCode": "version-unsupported", "supported": [2], "message": "V1 is not supported"}"#; | ||
|
|
||
| pub(crate) mod db; | ||
|
|
||
|
|
@@ -68,6 +70,7 @@ pub struct Service<D: Db> { | |
| db: D, | ||
| ohttp: ohttp::Server, | ||
| sentinel_tag: SentinelTag, | ||
| enable_v1: bool, | ||
| } | ||
|
|
||
| impl<D: Db, B> tower::Service<Request<B>> for Service<D> | ||
|
|
@@ -91,8 +94,8 @@ where | |
| } | ||
|
|
||
| impl<D: Db> Service<D> { | ||
| pub fn new(db: D, ohttp: ohttp::Server, sentinel_tag: SentinelTag) -> Self { | ||
| Self { db, ohttp, sentinel_tag } | ||
| pub fn new(db: D, ohttp: ohttp::Server, sentinel_tag: SentinelTag, enable_v1: bool) -> Self { | ||
| Self { db, ohttp, sentinel_tag, enable_v1 } | ||
| } | ||
|
|
||
| #[cfg(feature = "_manual-tls")] | ||
|
|
@@ -214,7 +217,7 @@ impl<D: Db> Service<D> { | |
| self.handle_ohttp_gateway_get(&query).await, | ||
| (Method::POST, ["", ""]) => self.handle_ohttp_gateway(body).await, | ||
| (Method::GET, ["", "ohttp-keys"]) => self.get_ohttp_keys().await, | ||
| (Method::POST, ["", id]) => self.post_fallback_v1(id, query, body).await, | ||
| (Method::POST, ["", id]) => self.handle_post_v1(id, query, body).await, | ||
| (Method::GET, ["", "health"]) => health_check().await, | ||
| (Method::GET, ["", ""]) => handle_directory_home_path().await, | ||
| _ => Ok(not_found()), | ||
|
|
@@ -227,6 +230,28 @@ impl<D: Db> Service<D> { | |
| Ok(response) | ||
| } | ||
|
|
||
| /// Route POST /{id}: forward to V1 fallback when enabled, otherwise reject. | ||
| async fn handle_post_v1<B>( | ||
| &self, | ||
| id: &str, | ||
| query: String, | ||
| body: B, | ||
| ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> | ||
| where | ||
| B: Body<Data = Bytes> + Send + 'static, | ||
| B::Error: Into<BoxError>, | ||
| { | ||
| if self.enable_v1 { | ||
| self.post_fallback_v1(id, query, body).await | ||
| } else { | ||
| let _ = (id, query, body); | ||
| Ok(Response::builder() | ||
| .status(StatusCode::BAD_REQUEST) | ||
| .header(CONTENT_TYPE, "application/json") | ||
| .body(full(V1_VERSION_UNSUPPORTED_RES_JSON))?) | ||
| } | ||
| } | ||
|
|
||
| /// Handle an encapsulated OHTTP request and return an encapsulated response | ||
| async fn handle_ohttp_gateway<B>( | ||
| &self, | ||
|
|
@@ -304,7 +329,7 @@ impl<D: Db> Service<D> { | |
| match (parts.method, path_segments.as_slice()) { | ||
| (Method::POST, &["", id]) => self.post_mailbox(id, body).await, | ||
| (Method::GET, &["", id]) => self.get_mailbox(id).await, | ||
| (Method::PUT, &["", id]) => self.put_payjoin_v1(id, body).await, | ||
| (Method::PUT, &["", id]) if self.enable_v1 => self.put_payjoin_v1(id, body).await, | ||
| _ => Ok(not_found()), | ||
| } | ||
| } | ||
|
|
@@ -603,3 +628,87 @@ fn empty() -> BoxBody<Bytes, hyper::Error> { | |
| fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> { | ||
| Full::new(chunk.into()).map_err(|never| match never {}).boxed() | ||
| } | ||
|
|
||
| #[cfg(test)] | ||
| mod tests { | ||
| use std::time::Duration; | ||
|
|
||
| use http_body_util::BodyExt; | ||
| use hyper::body::Bytes; | ||
| use hyper::{Method, Request, StatusCode}; | ||
| use ohttp_relay::SentinelTag; | ||
| use payjoin::directory::ShortId; | ||
|
|
||
| use super::*; | ||
|
|
||
| async fn test_service(enable_v1: bool) -> Service<FilesDb> { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this have any value to be pulled out into payjoin-test-utils to be used more generally?
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in theory we could move these tests to the payjoin-mailroom which already has a similar utility, but for now I think it's inoffensive enough until payjoin-directory gets folded in |
||
| let dir = tempfile::tempdir().expect("tempdir"); | ||
| let db = FilesDb::init(Duration::from_millis(100), dir.keep()).await.expect("db init"); | ||
| let ohttp: ohttp::Server = | ||
| key_config::gen_ohttp_server_config().expect("ohttp config").into(); | ||
| Service::new(db, ohttp, SentinelTag::new([0u8; 32]), enable_v1) | ||
| } | ||
|
|
||
| /// A valid ShortId encoded as bech32 for use in URL paths. | ||
| fn valid_short_id_path() -> String { | ||
| let id = ShortId([0u8; 8]); | ||
| id.to_string() | ||
| } | ||
|
|
||
| async fn collect_body(res: Response<BoxBody<Bytes, hyper::Error>>) -> (StatusCode, String) { | ||
| let (parts, body) = res.into_parts(); | ||
| let bytes = body.collect().await.unwrap().to_bytes(); | ||
| (parts.status, String::from_utf8(bytes.to_vec()).unwrap()) | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn post_v1_when_disabled_returns_version_unsupported() { | ||
| let mut svc = test_service(false).await; | ||
| let id = valid_short_id_path(); | ||
| let req = Request::builder() | ||
| .method(Method::POST) | ||
| .uri(format!("http://localhost/{id}")) | ||
| .body(Full::new(Bytes::from("base64-psbt"))) | ||
| .unwrap(); | ||
|
|
||
| let res = tower::Service::call(&mut svc, req).await.unwrap(); | ||
| let (status, body) = collect_body(res).await; | ||
|
|
||
| assert_eq!(status, StatusCode::BAD_REQUEST); | ||
| assert_eq!(body, V1_VERSION_UNSUPPORTED_RES_JSON); | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn post_v1_with_invalid_body_returns_reject() { | ||
| let mut svc = test_service(true).await; | ||
| let id = valid_short_id_path(); | ||
| let req = Request::builder() | ||
| .method(Method::POST) | ||
| .uri(format!("http://localhost/{id}")) | ||
| .body(Full::new(Bytes::from(vec![0xFF, 0xFE]))) | ||
| .unwrap(); | ||
|
|
||
| let res = tower::Service::call(&mut svc, req).await.unwrap(); | ||
| let (status, body) = collect_body(res).await; | ||
|
|
||
| assert_eq!(status, StatusCode::BAD_REQUEST); | ||
| assert_eq!(body, V1_REJECT_RES_JSON); | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn post_v1_with_no_receiver_returns_unavailable() { | ||
| let mut svc = test_service(true).await; | ||
| let id = valid_short_id_path(); | ||
| let req = Request::builder() | ||
| .method(Method::POST) | ||
| .uri(format!("http://localhost/{id}")) | ||
| .body(Full::new(Bytes::from("base64-psbt"))) | ||
| .unwrap(); | ||
|
|
||
| let res = tower::Service::call(&mut svc, req).await.unwrap(); | ||
| let (status, body) = collect_body(res).await; | ||
|
|
||
| assert_eq!(status, StatusCode::SERVICE_UNAVAILABLE); | ||
| assert_eq!(body, V1_UNAVAILABLE_RES_JSON); | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.