diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 5a788010e..1088fd2ce 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -27,6 +27,15 @@ jobs: # The toolchain must be specified manually, as this action ignores the rust-toolchain override # https://github.com/dtolnay/rust-toolchain?tab=readme-ov-file#inputs toolchain: ${{ matrix.rust }} + - name: "Install nix" + uses: cachix/install-nix-action@v25 + with: + nix_path: nixpkgs=channel:nixos-25.11 + # This is necessary for ohttp-relay integration tests + - name: "Add nginxWithStream to PATH" + run: | + nix build .#nginx-with-stream -o nginx + echo "$(pwd)/nginx/bin" >> $GITHUB_PATH - name: "Use cache" uses: Swatinem/rust-cache@v2 - name: Run tests @@ -107,6 +116,15 @@ jobs: - name: "Install toolchain" # rust-cache usage with stable Rust is most effective, as a cache is tied to the Rust version uses: dtolnay/rust-toolchain@stable + - name: "Install nix" + uses: cachix/install-nix-action@v25 + with: + nix_path: nixpkgs=channel:nixos-25.11 + # This is necessary for ohttp-relay integration tests + - name: "Add nginxWithStream to PATH" + run: | + nix build .#nginx-with-stream -o nginx + echo "$(pwd)/nginx/bin" >> $GITHUB_PATH - name: "Use cache" uses: Swatinem/rust-cache@v2 - name: "Install cargo-llvm-cov" diff --git a/Cargo-minimal.lock b/Cargo-minimal.lock index 71378aabb..b9007fb07 100644 --- a/Cargo-minimal.lock +++ b/Cargo-minimal.lock @@ -254,6 +254,16 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "async-compat" version = "0.2.5" @@ -866,6 +876,15 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" 
+[[package]] +name = "colored" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "compression-codecs" version = "0.4.30" @@ -1523,6 +1542,25 @@ dependencies = [ "scroll", ] +[[package]] +name = "h2" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "hashbrown" version = "0.14.5" @@ -1687,6 +1725,7 @@ dependencies = [ "bytes", "futures-channel", "futures-core", + "h2", "http", "http-body", "httparse", @@ -2085,11 +2124,11 @@ checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" [[package]] name = "matchers" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" dependencies = [ - "regex-automata 0.1.10", + "regex-automata", ] [[package]] @@ -2165,6 +2204,31 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "mockito" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e0603425789b4a70fcc4ac4f5a46a566c116ee3e2a6b768dc623f7719c611de" +dependencies = [ + "assert-json-diff", + "bytes", + "colored", + "futures-core", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "log", + "pin-project-lite", + "rand 0.9.1", + "regex", + "serde_json", + "serde_urlencoded", + "similar", + "tokio", +] + [[package]] name = "nix" version = "0.30.1" @@ -2190,12 +2254,11 @@ dependencies = [ [[package]] name 
= "nu-ansi-term" -version = "0.46.0" +version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "overload", - "winapi", + "windows-sys 0.60.2", ] [[package]] @@ -2244,24 +2307,29 @@ dependencies = [ [[package]] name = "ohttp-relay" version = "0.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15a28c940492277c063664f7a7aa9bc881bacf9659aacb06916666ea56af177a" dependencies = [ "byteorder", "bytes", "futures", + "hex-conservative 0.1.2", "http", "http-body-util", "hyper", "hyper-rustls", "hyper-tungstenite", "hyper-util", + "mockito", + "rcgen 0.12.1", + "reqwest", "rustls 0.23.31", + "tempfile", "tokio", + "tokio-rustls", "tokio-tungstenite", "tokio-util", "tracing", "tracing-subscriber", + "uuid", ] [[package]] @@ -2313,12 +2381,6 @@ dependencies = [ "hashbrown 0.14.5", ] -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "parking_lot" version = "0.11.2" @@ -2837,6 +2899,18 @@ dependencies = [ "getrandom 0.3.3", ] +[[package]] +name = "rcgen" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48406db8ac1f3cbc7dcdb56ec355343817958a356ff430259bb07baf7607e1e1" +dependencies = [ + "pem", + "ring", + "time", + "yasna", +] + [[package]] name = "rcgen" version = "0.13.2" @@ -2900,17 +2974,8 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", -] - -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" -dependencies = [ - "regex-syntax 0.6.29", + "regex-automata", + "regex-syntax", ] [[package]] @@ -2921,15 +2986,9 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.5", + "regex-syntax", ] -[[package]] -name = "regex-syntax" -version = "0.6.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" - [[package]] name = "regex-syntax" version = "0.8.5" @@ -2952,6 +3011,7 @@ dependencies = [ "base64 0.22.1", "bytes", "encoding_rs", + "futures-channel", "futures-core", "futures-util", "http", @@ -3412,6 +3472,12 @@ dependencies = [ "libc", ] +[[package]] +name = "similar" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" + [[package]] name = "siphasher" version = "0.3.11" @@ -3961,14 +4027,14 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ "matchers", "nu-ansi-term", "once_cell", - "regex", + "regex-automata", "sharded-slab", "smallvec", "thread_local", @@ -4289,12 +4355,14 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.16.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" +checksum = "e2e054861b4bd027cd373e18e8d8d8e6548085000e41290d95ce0c373a654b4a" dependencies = [ "getrandom 0.3.3", + "js-sys", "rand 0.9.1", + "wasm-bindgen", ] 
[[package]] diff --git a/Cargo-recent.lock b/Cargo-recent.lock index 71378aabb..b9007fb07 100644 --- a/Cargo-recent.lock +++ b/Cargo-recent.lock @@ -254,6 +254,16 @@ dependencies = [ "syn 2.0.106", ] +[[package]] +name = "assert-json-diff" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47e4f2b81832e72834d7518d8487a0396a28cc408186a2e8854c0f98011faf12" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "async-compat" version = "0.2.5" @@ -866,6 +876,15 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" +[[package]] +name = "colored" +version = "3.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fde0e0ec90c9dfb3b4b1a0891a7dcd0e2bffde2f7efed5fe7c9bb00e5bfb915e" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "compression-codecs" version = "0.4.30" @@ -1523,6 +1542,25 @@ dependencies = [ "scroll", ] +[[package]] +name = "h2" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "hashbrown" version = "0.14.5" @@ -1687,6 +1725,7 @@ dependencies = [ "bytes", "futures-channel", "futures-core", + "h2", "http", "http-body", "httparse", @@ -2085,11 +2124,11 @@ checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" [[package]] name = "matchers" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" dependencies = [ - 
"regex-automata 0.1.10", + "regex-automata", ] [[package]] @@ -2165,6 +2204,31 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "mockito" +version = "1.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e0603425789b4a70fcc4ac4f5a46a566c116ee3e2a6b768dc623f7719c611de" +dependencies = [ + "assert-json-diff", + "bytes", + "colored", + "futures-core", + "http", + "http-body", + "http-body-util", + "hyper", + "hyper-util", + "log", + "pin-project-lite", + "rand 0.9.1", + "regex", + "serde_json", + "serde_urlencoded", + "similar", + "tokio", +] + [[package]] name = "nix" version = "0.30.1" @@ -2190,12 +2254,11 @@ dependencies = [ [[package]] name = "nu-ansi-term" -version = "0.46.0" +version = "0.50.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" dependencies = [ - "overload", - "winapi", + "windows-sys 0.60.2", ] [[package]] @@ -2244,24 +2307,29 @@ dependencies = [ [[package]] name = "ohttp-relay" version = "0.0.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "15a28c940492277c063664f7a7aa9bc881bacf9659aacb06916666ea56af177a" dependencies = [ "byteorder", "bytes", "futures", + "hex-conservative 0.1.2", "http", "http-body-util", "hyper", "hyper-rustls", "hyper-tungstenite", "hyper-util", + "mockito", + "rcgen 0.12.1", + "reqwest", "rustls 0.23.31", + "tempfile", "tokio", + "tokio-rustls", "tokio-tungstenite", "tokio-util", "tracing", "tracing-subscriber", + "uuid", ] [[package]] @@ -2313,12 +2381,6 @@ dependencies = [ "hashbrown 0.14.5", ] -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "parking_lot" version = "0.11.2" @@ -2837,6 +2899,18 @@ 
dependencies = [ "getrandom 0.3.3", ] +[[package]] +name = "rcgen" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48406db8ac1f3cbc7dcdb56ec355343817958a356ff430259bb07baf7607e1e1" +dependencies = [ + "pem", + "ring", + "time", + "yasna", +] + [[package]] name = "rcgen" version = "0.13.2" @@ -2900,17 +2974,8 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", -] - -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" -dependencies = [ - "regex-syntax 0.6.29", + "regex-automata", + "regex-syntax", ] [[package]] @@ -2921,15 +2986,9 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.5", + "regex-syntax", ] -[[package]] -name = "regex-syntax" -version = "0.6.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" - [[package]] name = "regex-syntax" version = "0.8.5" @@ -2952,6 +3011,7 @@ dependencies = [ "base64 0.22.1", "bytes", "encoding_rs", + "futures-channel", "futures-core", "futures-util", "http", @@ -3412,6 +3472,12 @@ dependencies = [ "libc", ] +[[package]] +name = "similar" +version = "2.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa" + [[package]] name = "siphasher" version = "0.3.11" @@ -3961,14 +4027,14 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = 
"2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ "matchers", "nu-ansi-term", "once_cell", - "regex", + "regex-automata", "sharded-slab", "smallvec", "thread_local", @@ -4289,12 +4355,14 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "uuid" -version = "1.16.0" +version = "1.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "458f7a779bf54acc9f347480ac654f68407d3aab21269a6e3c9f922acd9e2da9" +checksum = "e2e054861b4bd027cd373e18e8d8d8e6548085000e41290d95ce0c373a654b4a" dependencies = [ "getrandom 0.3.3", + "js-sys", "rand 0.9.1", + "wasm-bindgen", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 414146911..67e22f9f9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,8 +1,9 @@ [workspace] -members = ["payjoin", "payjoin-cli", "payjoin-directory", "payjoin-test-utils", "payjoin-ffi"] +members = ["ohttp-relay", "payjoin", "payjoin-cli", "payjoin-directory", "payjoin-test-utils", "payjoin-ffi"] resolver = "2" [patch.crates-io] +ohttp-relay = { path = "ohttp-relay" } payjoin = { path = "payjoin" } payjoin-directory = { path = "payjoin-directory" } payjoin-test-utils = { path = "payjoin-test-utils" } diff --git a/contrib/test.sh b/contrib/test.sh index ed072b4eb..c1441afd3 100755 --- a/contrib/test.sh +++ b/contrib/test.sh @@ -24,7 +24,7 @@ if [ -f "$LOCKFILE" ]; then fi DEPS="recent minimal" -CRATES="payjoin payjoin-cli payjoin-directory payjoin-ffi" +CRATES="ohttp-relay payjoin payjoin-cli payjoin-directory payjoin-ffi" for dep in $DEPS; do cargo --version diff --git a/contrib/test_local.sh b/contrib/test_local.sh index 40ccd4ece..3563f8647 100755 --- a/contrib/test_local.sh +++ b/contrib/test_local.sh @@ -1,7 +1,7 @@ #!/usr/bin/env bash set -e -CRATES="payjoin payjoin-cli payjoin-directory payjoin-ffi" +CRATES="ohttp-relay payjoin payjoin-cli payjoin-directory payjoin-ffi" cargo --version rustc --version diff --git a/flake.nix b/flake.nix index 
b63a83c38..077b90537 100644 --- a/flake.nix +++ b/flake.nix @@ -28,6 +28,14 @@ }; msrv = "1.85.0"; + + nginxWithStream = pkgs.nginxMainline.overrideAttrs (oldAttrs: { + configureFlags = oldAttrs.configureFlags ++ [ + "--with-stream" + "--with-stream_ssl_module" + "--error-log-path=/dev/null" + ]; + }); rustVersions = with pkgs.rust-bin; builtins.mapAttrs @@ -111,6 +119,7 @@ "payjoin" = "--features v2"; "payjoin-cli" = "--features v1,v2"; "payjoin-directory" = ""; + "ohttp-relay" = ""; }; devShells = builtins.mapAttrs ( @@ -124,6 +133,7 @@ cargo-watch rust-analyzer dart + nginxWithStream ] ++ pkgs.lib.optionals (!pkgs.stdenv.isDarwin) [ cargo-llvm-cov @@ -143,7 +153,9 @@ ); in { - packages = packages; + packages = packages // { + nginx-with-stream = nginxWithStream; + }; devShells = devShells // { default = devShells.nightly; }; diff --git a/ohttp-relay/CHANGELOG.md b/ohttp-relay/CHANGELOG.md new file mode 100644 index 000000000..551e44f81 --- /dev/null +++ b/ohttp-relay/CHANGELOG.md @@ -0,0 +1,60 @@ +# ohttp-relay Changelog + +## 0.0.11 + +### MSRV Update and Dependency Modernization + +This release updates the minimum supported Rust version (MSRV) to 1.85.0, following the Debian stable update and aligning with other Bitcoin ecosystem projects. Along with this MSRV bump, all dependencies have been updated to their latest compatible versions. 
+ +#### Key Changes + +- **MSRV bump to 1.85.0** - Updated minimum supported Rust version for better ecosystem alignment +- **Rustls dependency updates** - Updated rustls dependencies while maintaining ring crypto provider support +- **Comprehensive dependency updates** - Updated tokio-tungstenite, hyper-tungstenite, mockito, and other dependencies to match the new MSRV +- **API compatibility** - Adapted to API changes in updated dependencies, particularly mockito and tungstenite + +#### Technical Details + +- Rustls updates continue to use ring's crypto provider by initializing the default provider +- Mockito API updates have been integrated to maintain testing functionality +- Tungstenite WebSocket implementation updated for improved compatibility +- All dependency updates maintain backward compatibility with existing functionality + +## 0.0.10 + +### Enable opt-in Gateway reachability for BIP 77 + +The [BIP 77 Draft](https://github.com/bitcoin/bips/pull/1483) imagines clients reach one another +over a "mailbox" store-and-forward server through OHTTP Relays. In order for Relays to reach those +mailbox servers without being pre-defined, this release includes support for an opt-in mechanism +based on [RFC 9540](https://www.rfc-editor.org/rfc/rfc9540.html)'s Oblivious Gateway discovery +mechanism augmented with an `allowed_purposes` parameter that may specify the BIP 77 mailbox as a +specific service. + +This was activated by implementing probing functionality that caches `allowed_purposes` responses +to prevent this Relay from being party to denial of service attacks where a client might spam +requests to Gateways that do not support an allowed purpose. 
+ +- RFC 9540 was implemented in [#47](https://github.com/payjoin/ohttp-relay/pull/47) +- RFC 9458 behavior was corrected in [#46](https://github.com/payjoin/ohttp-relay/pull/46) +- Internal abstractions and ergonomics were improved in [#50](https://github.com/payjoin/ohttp-relay/pull/50), [#57](https://github.com/payjoin/ohttp-relay/pull/57), [#59](https://github.com/payjoin/ohttp-relay/pull/59), [#60](https://github.com/payjoin/ohttp-relay/pull/60), [#62](https://github.com/payjoin/ohttp-relay/pull/62), and [#63](https://github.com/payjoin/ohttp-relay/pull/63). +- Gateway opt-in was introduced in [#58](https://github.com/payjoin/ohttp-relay/pull/58) + +### Gateway Probing and BIP77 Support +- Added gateway probing functionality with caching mechanism for improved performance [#46](https://github.com/payjoin/ohttp-relay/pull/46) +- Implemented BIP77 purpose string detection in allowed purposes response [#47](https://github.com/payjoin/ohttp-relay/pull/47) +- Added ALPN-encoded format parsing for gateway allowed purposes [#50](https://github.com/payjoin/ohttp-relay/pull/50) + +- https://github.com/payjoin/ohttp-relay/pull/46 +- https://github.com/payjoin/ohttp-relay/pull/47 +- https://github.com/payjoin/ohttp-relay/pull/50 +- https://github.com/payjoin/ohttp-relay/pull/57 +- https://github.com/payjoin/ohttp-relay/pull/58 +- https://github.com/payjoin/ohttp-relay/pull/59 +- https://github.com/payjoin/ohttp-relay/pull/60 +- https://github.com/payjoin/ohttp-relay/pull/62 +- https://github.com/payjoin/ohttp-relay/pull/63 + +## 0.0.9 + +- Add `_test-util` feature to allow testing with `listen_tcp_on_free_port` diff --git a/ohttp-relay/Cargo.toml b/ohttp-relay/Cargo.toml new file mode 100644 index 000000000..31140b476 --- /dev/null +++ b/ohttp-relay/Cargo.toml @@ -0,0 +1,47 @@ +[package] +name = "ohttp-relay" +version = "0.0.11" +authors = ["Dan Gould "] +description = "Relay Oblivious HTTP requests to protect IP metadata" +repository = "https://github.com/payjoin/rust-payjoin/tree/master/ohttp-relay" +readme = "README.md" +keywords = ["ohttp", "privacy"] +categories = 
["web-programming", "network-programming"] +license = "MITNFA" +edition = "2021" +rust-version = "1.85.0" +exclude = ["tests"] + +[features] +default = ["bootstrap"] +bootstrap = ["connect-bootstrap", "ws-bootstrap"] +connect-bootstrap = [] +ws-bootstrap = ["futures", "hyper-tungstenite", "rustls", "tokio-tungstenite"] +_test-util = [] + +[dependencies] +byteorder = "1.5.0" +bytes = "1.10.1" +futures = { version = "0.3.31", optional = true } +http = "1.3.1" +http-body-util = "0.1.3" +hyper = { version = "1.6.0", features = ["http1", "server"] } +hyper-rustls = { version = "0.27.7", default-features=false, features = ["webpki-roots", "http1", "ring"] } +hyper-tungstenite = { version = "0.18.0", optional = true } +hyper-util = { version = "0.1.16", features = ["client-legacy"] } +rustls = { version = "0.23.31", optional = true, default-features=false, features = ["ring"] } +tokio = { version = "1.47.1", features = ["io-std", "macros", "net", "rt-multi-thread"] } +tokio-tungstenite = { version = "0.27.0", optional = true } +tokio-util = { version = "0.7.16", features = ["net", "codec"] } +tracing = "0.1.41" +tracing-subscriber = { version = "0.3.20", features = ["env-filter"] } + +[dev-dependencies] +hex = { package = "hex-conservative", version = "0.1.1" } +mockito = "1.7.0" +rcgen = "0.12" +tempfile = "3.20.0" +tokio = { version = "1.47.1", features = ["process", "test-util"] } +tokio-rustls = { version = "0.26.2", default-features=false, features = ["ring"]} +reqwest = { version = "0.12.23", default-features= false ,features = ["rustls-tls", "blocking"] } +uuid = { version = "1.18.0", features = ["v4"] } diff --git a/ohttp-relay/Dockerfile b/ohttp-relay/Dockerfile new file mode 100644 index 000000000..ec6c2d75d --- /dev/null +++ b/ohttp-relay/Dockerfile @@ -0,0 +1,28 @@ +# Stage 1: Building the binary +FROM nixos/nix:2.20.5 AS builder + +# Copy our source and setup our working dir. +COPY . 
/tmp/build +WORKDIR /tmp/build + +# Build our Nix environment +RUN nix \ + --extra-experimental-features "nix-command flakes" \ + --option filter-syscalls false \ + build + +# Copy the Nix store closure into a directory. The Nix store closure is the +# entire set of Nix store values that we need for our build. +RUN mkdir /tmp/nix-store-closure +RUN cp -R $(nix-store -qR result/) /tmp/nix-store-closure + +# Final image is based on scratch. We copy a bunch of Nix dependencies +# but they're fully self-contained so we don't need Nix anymore. +FROM scratch + +WORKDIR /ohttp-relay + +# Copy /nix/store +COPY --from=builder /tmp/nix-store-closure /nix/store +COPY --from=builder /tmp/build/result /ohttp-relay +CMD ["/ohttp-relay/bin/ohttp-relay"] diff --git a/ohttp-relay/README.md b/ohttp-relay/README.md new file mode 100644 index 000000000..79e073446 --- /dev/null +++ b/ohttp-relay/README.md @@ -0,0 +1,27 @@ +# OHTTP Relay + +A Rust implementation of an [Oblivious +HTTP](https://ietf-wg-ohai.github.io/oblivious-http/draft-ietf-ohai-ohttp.html) relay resource. + +This work is undergoing active revision in the IETF and so are these +implementations. Use at your own risk. + +## Usage + +Run ohttp-relay by setting `PORT` and `GATEWAY_ORIGIN` environment variables. For example, to relay from port 3000 to an OHTTP Gateway Resource at `https://payjo.in`, run the following. + +```console +PORT=3000 GATEWAY_ORIGIN='https://payjo.in' cargo run +``` + +Alternatively, set `UNIX_SOCKET` to bind to a Unix socket path instead of a TCP port. + +This crate is intended to be run behind a reverse proxy like NGINX that can handle TLS for you. Tests specifically cover this integration using `nginx.conf.template`. + +## Bootstrap Feature + +The Oblivious HTTP specification requires clients to obtain a [Key Configuration](https://www.ietf.org/rfc/rfc9458.html#name-key-configuration) from the OHTTP Gateway but leaves a mechanism for doing so explicitly unspecified.
This feature hosts HTTPS-in-WebSocket and HTTPS-in-CONNECT proxies to allow web clients to GET a gateway's ohttp-keys via [Direct Discovery](https://datatracker.ietf.org/doc/html/draft-ietf-privacypass-key-consistency-01#name-direct-discovery) in an end-to-end-encrypted, authenticated manner using the OHTTP relay as a tunnel so as not to reveal their IP address. The `bootstrap` feature to host these proxies is enabled by default. The `ws-bootstrap` and `connect-bootstrap` features enable each proxy individually. + +### How does it work? + +Both bootstrap features enable the server to forward packets directly to and from the OHTTP Gateway's TCP socket to negotiate a TLS session between the client and gateway. By doing so, the OHTTP Relay is prevented from conducting a [man-in-the-middle attack](https://en.wikipedia.org/wiki/Man-in-the-middle_attack) to compromise the TLS session. diff --git a/ohttp-relay/contrib/test.sh b/ohttp-relay/contrib/test.sh new file mode 100755 index 000000000..070ba97e5 --- /dev/null +++ b/ohttp-relay/contrib/test.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +set -e + +cargo test --locked --package ohttp-relay --verbose --all-features --lib +cargo test --locked --package ohttp-relay --verbose --all-features --test integration diff --git a/ohttp-relay/nginx.conf.template b/ohttp-relay/nginx.conf.template new file mode 100644 index 000000000..bd17fc1ba --- /dev/null +++ b/ohttp-relay/nginx.conf.template @@ -0,0 +1,23 @@ +error_log {{error_log_path}} debug; +pid {{pid_path}}; + +events { + worker_connections 1024; +} + +stream { + server { + listen {{http_port}}; + + proxy_pass {{proxy_pass}}; + } + + server { + listen {{https_port}} ssl; + + ssl_certificate {{cert_path}}; + ssl_certificate_key {{key_path}}; + + proxy_pass {{proxy_pass}}; + } +} \ No newline at end of file diff --git a/ohttp-relay/src/bootstrap/connect.rs b/ohttp-relay/src/bootstrap/connect.rs new file mode 100644 index 000000000..3e57e45cf --- /dev/null +++ 
b/ohttp-relay/src/bootstrap/connect.rs @@ -0,0 +1,51 @@ +use std::net::SocketAddr; + +use http_body_util::combinators::BoxBody; +use hyper::body::{Bytes, Incoming}; +use hyper::upgrade::Upgraded; +use hyper::{Method, Request, Response}; +use hyper_util::rt::TokioIo; +use tokio::net::TcpStream; +use tracing::{error, instrument}; + +use crate::error::Error; +use crate::{empty, GatewayUri}; + +pub(crate) fn is_connect_request(req: &Request) -> bool { + Method::CONNECT == req.method() +} + +#[instrument] +pub(crate) async fn try_upgrade( + req: Request, + gateway_origin: GatewayUri, +) -> Result>, Error> { + let addr = gateway_origin + .to_socket_addr() + .await + .map_err(|e| Error::InternalServerError(Box::new(e)))? + .ok_or_else(|| Error::NotFound)?; + + tokio::task::spawn(async move { + match hyper::upgrade::on(req).await { + Ok(upgraded) => { + if let Err(e) = tunnel(upgraded, addr).await { + error!("server io error: {}", e); + }; + } + Err(e) => error!("upgrade error: {}", e), + } + }); + + Ok(Response::new(empty())) +} + +/// Create a TCP connection to host:port, build a tunnel between the connection and +/// the upgraded connection +#[instrument] +async fn tunnel(upgraded: Upgraded, addr: SocketAddr) -> std::io::Result<()> { + let mut server = TcpStream::connect(addr).await?; + let mut upgraded = TokioIo::new(upgraded); + let (_, _) = tokio::io::copy_bidirectional(&mut upgraded, &mut server).await?; + Ok(()) +} diff --git a/ohttp-relay/src/bootstrap/mod.rs b/ohttp-relay/src/bootstrap/mod.rs new file mode 100644 index 000000000..f3ae187c2 --- /dev/null +++ b/ohttp-relay/src/bootstrap/mod.rs @@ -0,0 +1,31 @@ +use http_body_util::combinators::BoxBody; +use hyper::body::{Bytes, Incoming}; +use hyper::{Request, Response}; +use tracing::instrument; + +use crate::error::Error; +use crate::GatewayUri; + +#[cfg(feature = "connect-bootstrap")] +pub mod connect; + +#[cfg(feature = "ws-bootstrap")] +pub mod ws; + +#[instrument] +pub(crate) async fn handle_ohttp_keys( + mut 
req: Request, + gateway_origin: GatewayUri, +) -> Result>, Error> { + #[cfg(feature = "connect-bootstrap")] + if connect::is_connect_request(&req) { + return connect::try_upgrade(req, gateway_origin).await; + } + + #[cfg(feature = "ws-bootstrap")] + if ws::is_websocket_request(&req) { + return ws::try_upgrade(&mut req, gateway_origin).await; + } + + Err(Error::BadRequest("Not a supported proxy upgrade request".to_string())) +} diff --git a/ohttp-relay/src/bootstrap/ws.rs b/ohttp-relay/src/bootstrap/ws.rs new file mode 100644 index 000000000..feffa0fd1 --- /dev/null +++ b/ohttp-relay/src/bootstrap/ws.rs @@ -0,0 +1,170 @@ +use std::io; +use std::net::SocketAddr; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::{Sink, SinkExt, StreamExt}; +use http_body_util::combinators::BoxBody; +use http_body_util::BodyExt; +use hyper::body::{Bytes, Incoming}; +use hyper::{Request, Response}; +use hyper_tungstenite::HyperWebsocket; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_tungstenite::tungstenite::protocol::Message; +use tokio_tungstenite::{tungstenite, WebSocketStream}; +use tracing::{error, instrument}; + +use crate::error::Error; +use crate::gateway_uri::GatewayUri; + +pub(crate) fn is_websocket_request(req: &Request) -> bool { + hyper_tungstenite::is_upgrade_request(req) +} + +#[instrument] +pub(crate) async fn try_upgrade( + req: &mut Request, + gateway_origin: GatewayUri, +) -> Result>, Error> { + let gateway_addr = gateway_origin + .to_socket_addr() + .await + .map_err(|e| Error::InternalServerError(Box::new(e)))? 
+ .ok_or_else(|| Error::NotFound)?; + + let (res, websocket) = hyper_tungstenite::upgrade(req, None) + .map_err(|e| Error::BadRequest(format!("Error upgrading to websocket: {}", e)))?; + + tokio::spawn(async move { + if let Err(e) = serve_websocket(websocket, gateway_addr).await { + error!("Error in websocket connection: {e}"); + } + }); + let (parts, body) = res.into_parts(); + let boxbody = body.map_err(|never| match never {}).boxed(); + Ok(Response::from_parts(parts, boxbody)) +} + +/// Stream WebSocket frames from the client to the gateway server's TCP socket and vice versa. +#[instrument] +async fn serve_websocket( + websocket: HyperWebsocket, + gateway_addr: SocketAddr, +) -> Result<(), Box> { + let mut tcp_stream = tokio::net::TcpStream::connect(gateway_addr).await?; + let mut ws_io = WsIo::new(websocket.await?); + let (_, _) = tokio::io::copy_bidirectional(&mut ws_io, &mut tcp_stream).await?; + Ok(()) +} + +pub struct WsIo +where + S: AsyncRead + AsyncWrite + Unpin, +{ + ws_stream: WebSocketStream, + read_buffer: Vec, +} + +impl WsIo +where + S: AsyncRead + AsyncWrite + Unpin, +{ + pub fn new(ws_stream: WebSocketStream) -> Self { + WsIo { ws_stream, read_buffer: Vec::new() } + } +} + +impl AsyncRead for WsIo +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + let self_mut = self.get_mut(); + + // If the read buffer has data, use it first. + if !self_mut.read_buffer.is_empty() { + let len = std::cmp::min(buf.remaining(), self_mut.read_buffer.len()); + buf.put_slice(&self_mut.read_buffer[..len]); + self_mut.read_buffer.drain(..len); + return Poll::Ready(Ok(())); + } + // Otherwise, try to read a new frame. 
+ match self_mut.ws_stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(message))) => match message { + Message::Binary(data) => { + self_mut.read_buffer.extend_from_slice(&data); + let len = std::cmp::min(buf.remaining(), self_mut.read_buffer.len()); + buf.put_slice(&self_mut.read_buffer[..len]); + self_mut.read_buffer.drain(..len); + Poll::Ready(Ok(())) + } + Message::Ping(data) => start_send(&mut self_mut.ws_stream, Message::Pong(data)), + Message::Pong(_) => { + // Usually, no action is needed on pong messages + Poll::Pending + } + Message::Close(_) => start_send(&mut self_mut.ws_stream, Message::Close(None)), + _ => Poll::Pending, + }, + Poll::Ready(None) => { + // No more messages will be received because the WebSocket stream is closed. + // If there's no data left in the read buffer, we signify EOF by returning Ok. + if self_mut.read_buffer.is_empty() { + Poll::Ready(Ok(())) // Signify EOF + } else { + // If there's still data left in the buffer, we need to return that first. + // This ensures that the caller can consume all remaining data before receiving EOF. 
+ let len = std::cmp::min(buf.remaining(), self_mut.read_buffer.len()); + buf.put_slice(&self_mut.read_buffer[..len]); + self_mut.read_buffer.drain(..len); + Poll::Ready(Ok(())) + } + } + Poll::Ready(Some(Err(e))) => Poll::Ready(Err(map_ws_error(e))), + Poll::Pending => Poll::Pending, + } + } +} + +impl AsyncWrite for WsIo +where + S: AsyncRead + AsyncWrite + Unpin, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + data: &[u8], + ) -> Poll> { + let self_mut = self.get_mut(); + match Pin::new(&mut self_mut.ws_stream).poll_ready(cx) { + Poll::Ready(Ok(())) => + start_send(&mut self_mut.ws_stream, Message::Binary(data.to_vec().into())) + .map(|r| r.map(|_| data.len())), + Poll::Ready(Err(e)) => Poll::Ready(Err(map_ws_error(e))), + Poll::Pending => Poll::Pending, + } + } + + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().ws_stream).poll_flush(cx).map_err(map_ws_error) + } + + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + Pin::new(&mut self.get_mut().ws_stream).poll_close(cx).map_err(map_ws_error) + } +} + +fn start_send( + ws_stream: &mut WebSocketStream, + data: Message, +) -> Poll> { + Poll::Ready(ws_stream.start_send_unpin(data).map_err(map_ws_error)) +} + +fn map_ws_error(e: tungstenite::Error) -> io::Error { + io::Error::new(io::ErrorKind::BrokenPipe, format!("Tungstenite error: {}", e)) +} diff --git a/ohttp-relay/src/error.rs b/ohttp-relay/src/error.rs new file mode 100644 index 000000000..07f73c84d --- /dev/null +++ b/ohttp-relay/src/error.rs @@ -0,0 +1,69 @@ +use std::time::Duration; + +use http_body_util::combinators::BoxBody; +use http_body_util::BodyExt; +use hyper::body::Bytes; +use hyper::header::{HeaderValue, RETRY_AFTER}; +use hyper::{Response, StatusCode}; +use tracing::error; + +use crate::{empty, full}; + +pub(crate) type BoxError = Box; + +#[derive(Debug)] +#[allow(clippy::enum_variant_names)] +pub(crate) enum Error { + BadGateway, + 
MethodNotAllowed, + UnsupportedMediaType, + BadRequest(String), + NotFound, + InternalServerError(BoxError), + Unavailable(Duration), +} + +impl Error { + pub fn to_response(&self) -> Response> { + let mut res = Response::new(empty()); + match self { + Self::UnsupportedMediaType => *res.status_mut() = StatusCode::UNSUPPORTED_MEDIA_TYPE, + Self::BadGateway => *res.status_mut() = StatusCode::BAD_GATEWAY, + Self::MethodNotAllowed => *res.status_mut() = StatusCode::METHOD_NOT_ALLOWED, + Self::BadRequest(e) => { + *res.status_mut() = StatusCode::BAD_REQUEST; + *res.body_mut() = full(e.to_string()).boxed(); + } + Self::NotFound => *res.status_mut() = StatusCode::NOT_FOUND, + Self::InternalServerError(internal_error) => { + error!("Internal server error: {}", internal_error); + *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + } + Self::Unavailable(max_age) => { + *res.status_mut() = StatusCode::SERVICE_UNAVAILABLE; + res.headers_mut().append( + RETRY_AFTER, + HeaderValue::from_str(&max_age.as_secs().to_string()) + .expect("header value should always be valid"), + ); + } + }; + res + } +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::UnsupportedMediaType => write!(f, "Unsupported media type"), + Self::BadGateway => write!(f, "Bad gateway"), + Self::MethodNotAllowed => write!(f, "Method not allowed"), + Self::BadRequest(e) => write!(f, "Bad request: {}", e), + Self::NotFound => write!(f, "Not found"), + Self::InternalServerError(e) => write!(f, "Internal server error: {}", e), + Self::Unavailable(_) => write!(f, "Service unavailable"), + } + } +} + +impl std::error::Error for Error {} diff --git a/ohttp-relay/src/gateway_prober.rs b/ohttp-relay/src/gateway_prober.rs new file mode 100644 index 000000000..7f0234ec0 --- /dev/null +++ b/ohttp-relay/src/gateway_prober.rs @@ -0,0 +1,768 @@ +use std::cmp::{Ordering, Reverse}; +use std::collections::{BinaryHeap, HashMap}; +use 
std::error::Error; +use std::io::{ErrorKind, Read}; +use std::time::Duration; + +use byteorder::{BigEndian, ReadBytesExt}; +use bytes::BytesMut; +use futures::future::{self, FutureExt}; +use http_body_util::BodyExt; +use hyper::body::Incoming; +use tokio::sync::{oneshot, RwLock}; +use tokio::time::Instant; + +use crate::gateway_uri::GatewayUri; + +// these are only pub for the integration test +pub const MAGIC_BIP77_PURPOSE: &[u8] = b"BIP77 454403bb-9f7b-4385-b31f-acd2dae20b7e"; +pub const ALLOWED_PURPOSES_CONTENT_TYPE: &str = "application/x-ohttp-allowed-purposes"; +const DEFAULT_CAPACITY: usize = 1000; + +#[derive(PartialEq, Eq, Clone, Copy, Debug)] +pub(crate) struct Policy { + pub(crate) bip77_allowed: bool, + pub(crate) expires: Instant, +} + +impl Policy { + fn always(bip77_allowed: bool) -> Self { + // Rationale for thirty years is same as tokio's Instant::far_future, + // this value is portable and will not overflow for foreseeable future + const THIRTY_YEARS: Duration = Duration::from_secs(30 * 365 * 24 * 60 * 60); + let expires = Instant::now() + THIRTY_YEARS; + Self { bip77_allowed, expires } + } +} + +#[derive(Debug)] +enum Status { + InFlight(future::Shared>), + Known(Policy), +} + +#[derive(Default, Debug)] +pub(crate) struct Prober { + gateways: RwLock, + ttl_config: TTLConfig, + client: super::HttpClient, +} + +#[derive(Debug)] +struct KnownGateways { + capacity: usize, + by_url: HashMap, + by_expiry: BinaryHeap, +} + +#[derive(PartialEq, Eq, Debug)] +struct HeapEntry { + expires: Instant, + key: GatewayUri, +} + +impl Ord for HeapEntry { + /// Reverse ordering by expires for min-heap semantics + fn cmp(&self, other: &Self) -> Ordering { Reverse(self.expires).cmp(&Reverse(other.expires)) } +} + +impl PartialOrd for HeapEntry { + fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } +} + +impl Default for KnownGateways { + fn default() -> Self { + Self { + capacity: DEFAULT_CAPACITY, + by_url: HashMap::default(), + by_expiry: 
BinaryHeap::default(), + } + } +} + +impl KnownGateways { + fn get(&mut self, url: &GatewayUri) -> Option<&Status> { + // eager pruning because the borrow checker gets upset by the commented + // out lazy version below + self.prune(); + + let status = self.by_url.get(url); + + // if let Some(GatewayStatus::Known(policy)) = status { + // status = None; // does not appease borrow checker + // if !policy.expires.elapsed().is_zero() { + // self.prune(); + // return None; + // } + // } + + status + } + + fn prune(&mut self) { + debug_assert!(self.by_expiry.len() <= self.by_url.len()); + while let Some(entry) = self.by_expiry.peek() { + if entry.expires.elapsed().is_zero() { + break; + } + + self.by_url.remove(&entry.key); + self.by_expiry.pop(); + } + debug_assert!(self.by_expiry.len() <= self.by_url.len()); + } + + fn has_capacity(&mut self) -> bool { + self.prune(); + + self.by_url.len() < self.capacity + } + + fn no_capacity_for(&mut self) -> Duration { + if self.has_capacity() { + return Duration::ZERO; + } + + self.by_expiry + .peek() + .map(|e| e.expires.saturating_duration_since(Instant::now())) + .unwrap_or(Duration::ZERO) + } + + fn allocate_in_flight(&mut self, uri: &GatewayUri) -> Option> { + if !self.has_capacity() { + return None; + } + + if self.by_url.contains_key(uri) { + return None; + } + + let (sender, receiver) = oneshot::channel::(); + _ = self.by_url.insert(uri.clone(), Status::InFlight(receiver.shared())); + + Some(sender) + } + + fn insert(&mut self, url: &GatewayUri, policy: Policy) -> Option<()> { + // could use try_insert()? but that's an unstable feature + // we want to avoid duplicate insertions because updating TTL requires + // scanning the heap, or having multiple heap entries per key, which + // complicates things unnecessarily. + // however if the existing entry is inflight, that can and should be + // overwritten exactly once. 
+ if let Some(Status::Known(_)) = self.by_url.get(url) { + return None; + } + + debug_assert!(self.by_expiry.len() <= self.by_url.len()); + + // a more robust approach might be to keep the sender in the map as + // well, send() to it here, and ensure that it is the right interface, + // this should be possible since oneshot does not require the sender to + // be async so that should be possible, but still requires using this + // method externally. + // making the entries some kind of atomic pointer implementing the + // equivalent of a haskell LVar can ensure that the state machine per + // entry is always inflight -> inserted, but that seems much more + // complex, so instead we just overwrite any existing entry and tolerate + // inflight ones not being in the map for simplicity. + _ = self.by_url.insert(url.clone(), Status::Known(policy)); + self.by_expiry.push(HeapEntry { expires: policy.expires, key: url.clone() }); + + Some(()) + } +} + +impl Prober { + pub(crate) fn new_with_client(client: super::HttpClient) -> Self { + Self { client, ..Self::default() } + } + + /// Permanently mark a gateway authority as allowed. + pub(crate) async fn assert_opt_in(&self, url: &GatewayUri) -> Option<()> { + let mut locked_map = self.gateways.write().await; + locked_map.insert(url, Policy::always(true)) + } + + /// Check whether a gateway is allowed. If the policy is not known, + /// the gateway will be probed. 
+ pub(crate) async fn check_opt_in(&self, url: &GatewayUri) -> Option { + let inflight = { + let mut locked_map = self.gateways.write().await; + match locked_map.get(url) { + Some(Status::Known(policy)) => return Some(*policy), + Some(Status::InFlight(receiver)) => Ok(receiver.clone()), + None => { + // Only actually query the url if this is the first + // lookup and the map is not over capacity + let sender = locked_map.allocate_in_flight(url)?; + Err(sender) + } + } + }; + + Some(match inflight { + Ok(receiver) => receiver.await.expect("probe task should never be dropped"), + Err(sender) => { + let policy = self.probe(url).await; + + { + let mut locked_map = self.gateways.write().await; + locked_map.insert(url, policy); + } + + _ = sender.send(policy); + + policy + } + }) + } + + async fn is_explicit_opt_in(res: &mut hyper::Response) -> Option<()> { + if res.status() != hyper::StatusCode::OK { + return None; + } + + let mut body = BytesMut::new(); + while let Some(next) = res.frame().await { + let frame = next.ok()?; + if let Some(chunk) = frame.data_ref() { + body.extend_from_slice(chunk) + } + } + + if res.headers().get(hyper::header::CONTENT_TYPE)? + != hyper::header::HeaderValue::from_static(ALLOWED_PURPOSES_CONTENT_TYPE) + { + return None; + } + + let allowed_purposes = parse_alpn_encoded(&body).ok()?; + if allowed_purposes.contains(&MAGIC_BIP77_PURPOSE.to_vec()) { + return Some(()); + } + + None + } + + /// Probes a target gateway by attempting to send a GET request. 
+ async fn probe(&self, base_url: &GatewayUri) -> Policy { + // Create a GET request without a body + let req = hyper::Request::builder() + .method(hyper::Method::GET) + .uri(base_url.probe_url()) + .body(http_body_util::combinators::BoxBody::::new( + http_body_util::Empty::new().map_err(|_| { + panic!("infallible error type should never produce an actual error to map") + }), + )) + .expect("creating GET request must succeed"); + + let mut res = self.client.request(req).await; + + // opt-in is tracked via a separate mutable variable since it only + // occurs in the first sub-branch of this large conditional, which is + // largely concerned with determining the TTL + let mut bip77_allowed = false; + + let ttls = &self.ttl_config; + let ttl = match &mut res { + Ok(res) => { + // TODO handle Cache-Control + let status = res.status(); + + if status.is_success() { + bip77_allowed = Self::is_explicit_opt_in(res).await.is_some(); + + if bip77_allowed { + ttls.opt_in + } else { + ttls.http_2xx + } + } else if status == hyper::StatusCode::GATEWAY_TIMEOUT { + ttls.http_504_gateway_timeout + } else if status.is_client_error() { + // TODO handle Retry-After for 429 too many requests + ttls.http_4xx + } else if status.is_server_error() { + // TODO handle Retry-After for 503 service unavailable + ttls.http_5xx + } else { + ttls.default + } + } + Err(err) => { + if let Some(io_error) = + err.source().and_then(|source| source.downcast_ref::()) + { + match io_error.kind() { + ErrorKind::NotFound => ttls.dns, + ErrorKind::TimedOut => ttls.timedout, + ErrorKind::ConnectionReset => ttls.reset_by_peer, + _ => ttls.default, + } + } else { + ttls.default + } + } + }; + + Policy { bip77_allowed, expires: Instant::now() + ttl } + } + + pub(crate) async fn unavailable_for(&self) -> Duration { + let mut locked_map = self.gateways.write().await; + locked_map.no_capacity_for() + } +} + +fn parse_alpn_encoded(input: &[u8]) -> std::io::Result>> { + let mut input = input; + let mut output: Vec> = 
Vec::with_capacity(input.read_u16::()?.into()); + + while output.capacity() != output.len() { + let mut buf = vec![0u8; input.read_u8()?.into()]; + input.read_exact(&mut buf)?; + output.push(buf); + } + + if !input.is_empty() { + return Err(std::io::Error::new( + std::io::ErrorKind::UnexpectedEof, + "Unexpected trailing data", + )); + } + + Ok(output) +} + +#[derive(Debug)] +struct TTLConfig { + /// Explicit opt-in, defaults to LONG. + opt_in: Duration, + + // everything else is an opt-out + /// Any other 2xx response, for example ohttp-keys which indicate no + /// opt-in. Defaults to LONG to avoid spamming servers. + http_2xx: Duration, + /// Any 4xx response, for example 404. Defaults to LONG to avoid + /// spamming servers. + http_4xx: Duration, + /// TTL for 504 gateway timeout. Defaults to NONE assuming that is transient. + http_504_gateway_timeout: Duration, + /// Any other 5xx response, for example internal server error. Defaults to + /// SHORT. + http_5xx: Duration, + + // io errors, should be ephemeral + /// TTL for host not found. Defaults to NONE assuming host name resolution and/or DNS resolver cache negative results. + dns: Duration, + /// TTL for reset by peer errors. Defaults to NONE as that is transient. + reset_by_peer: Duration, + /// TTL for tcp timeout. Defaults to NONE as that is transient. + timedout: Duration, + + /// For other errors, default to SHORT enforce rudimentary rate limiting + default: Duration, +} + +/// Different probing results/conditions and the time to live when caching that +/// information. 
+impl Default for TTLConfig { + fn default() -> Self { + /// A week + const LONG: Duration = Duration::from_secs(7 * 24 * 60 * 60); + /// 5 seconds + const SHORT: Duration = Duration::from_secs(5); + /// 0 seconds + const NONE: Duration = Duration::ZERO; + + Self { + opt_in: LONG, + http_2xx: LONG, + http_4xx: LONG, + http_504_gateway_timeout: NONE, + http_5xx: SHORT, + dns: NONE, + reset_by_peer: NONE, + timedout: NONE, + default: SHORT, + } + } +} + +#[cfg(test)] +mod tests { + use std::str::FromStr; + use std::sync::{Arc, Mutex}; + use std::time::Duration; + + use mockito::Server; + use tokio::time::advance; + + use super::*; + use crate::gateway_uri::RFC_9540_GATEWAY_PATH; + + const BIP77_OPT_IN_RESPONSE: &[u8] = b"\x00\x01\x2aBIP77 454403bb-9f7b-4385-b31f-acd2dae20b7e"; + const TIMESTEP: Duration = Duration::from_secs(1); // only used with advance() + const EPSILON: Duration = Duration::from_millis(1); // only used with advance() + + #[tokio::test(start_paused = true)] + async fn test_known_gateways() { + let mut db = KnownGateways { capacity: 1, ..Default::default() }; + + let url = GatewayUri::from_static("https://payjo.in"); + + assert!(db.has_capacity(), "known gateway set should be empty"); + assert!(db.no_capacity_for().is_zero(), "capacity should be available right now"); + assert!(db.get(&url).is_none(), "mock gateway should not yet be known"); + + let policy = Policy { bip77_allowed: true, expires: Instant::now() + TIMESTEP }; + + // see comment in implementation of insert(), arguably this should not + // be allowed as the state machine should start with inflight, but this + // behavior is simpler and given that's what's implemented it should be + // tested. 
+ assert!(db.insert(&url, policy).is_some(), "insertion of gateway policy should succeed"); + if let Some(Status::Known(got)) = db.get(&url) { + assert_eq!(*got, policy, "initially inserted policy should be retrievable"); + } else { + panic!("initially inserted policy should be retrievable"); + } + + // No overwriting + assert!( + db.allocate_in_flight(&url).is_none(), + "allocating inflight future for known gateway should fail" + ); + assert!( + db.insert(&url, Policy { bip77_allowed: false, expires: Instant::now() + TIMESTEP }) + .is_none(), + "inserting a duplicate policy entry should fail" + ); + if let Some(Status::Known(got)) = db.get(&url) { + assert_eq!(*got, policy, "initially inserted policy should be retrievable"); + } else { + panic!("initially inserted policy should be retrievable"); + } + + // Pruning + assert!(!db.has_capacity(), "known gateway set should be at capacity"); + assert!( + !db.no_capacity_for().is_zero(), + "next capacity availability should be in the future" + ); + advance(TIMESTEP + EPSILON).await; + assert!(db.has_capacity(), "after waiting for expiry capacity should be available"); + assert!(db.no_capacity_for().is_zero(), "capacity should be available right now",); + + // Insert expired + assert!( + db.insert(&url, Policy { bip77_allowed: false, expires: Instant::now() }).is_some(), + "inserting an expired entry should not fail" + ); + assert!( + db.get(&url).is_some(), + "inserted expired entry should be retrievable in same instant" + ); + advance(EPSILON).await; + assert!(db.get(&url).is_none(), "inserted expired entry should not be retrievable"); + assert!( + db.has_capacity(), + "after inserting expired entry capacity should still be available" + ); + + let inflight = + db.allocate_in_flight(&url).expect("allocating inflight entry should succeed"); + + if let Some(Status::InFlight(got)) = db.get(&url) { + assert!(got.peek().is_none(), "inflight entry future should still be pending"); + + inflight.send(policy).expect("oneshot 
channel should accept a value"); + + assert_eq!( + got.clone().await.expect("inflight future should have been resolved"), + policy + ); + } else { + panic!("inflight entry should be retrievable"); + } + + // Upgrade in-flight to known entry... this too kind of violates encapsulation + assert!( + !db.has_capacity(), + "with an inflight entry, known gateway set should be at capacity" + ); + assert!( + db.insert(&url, Policy { bip77_allowed: true, expires: Instant::now() + TIMESTEP }) + .is_some(), + "inserting known entry to overwrite inflight one should succeed even at capacity" + ); + + // Test heap behavior + assert!(!db.has_capacity(), "after inserting known entry set should still be at capacity"); + assert!( + !db.no_capacity_for().is_zero(), + "and next capacity availability should be in the future" + ); + db.capacity = 2; + assert!(db.has_capacity(), "after raising limit, set should no longer be at capacity"); + assert!(db.no_capacity_for().is_zero(), "capacity should be available right now",); + + let url_2 = GatewayUri::from_static("https://payspl.it"); + + assert!(db.get(&url).is_some(), "previously inserted entry should still be in the set"); + assert!(db.get(&url_2).is_none(), "unknown entry should not be in the set"); + + assert!( + db.insert( + &url_2, + Policy { bip77_allowed: false, expires: Instant::now() + (2 * TIMESTEP) } + ) + .is_some(), + "inserting second entry should succeed" + ); + assert!(!db.has_capacity(), "after insertion gateway set should be at capacity"); + + assert!(db.get(&url).is_some(), "retrieving initially inserted entry should succeed"); + assert!(db.get(&url_2).is_some(), "retrieving second inserted entry should succeed"); + + advance(TIMESTEP + EPSILON).await; + + assert!(db.get(&url).is_none(), "after delay initially inserted entry should have expired"); + assert!(db.get(&url_2).is_some(), "second inserted entry should still be retrievable"); + + assert!(db.has_capacity(), "after expiry, capacity should be available"); + 
db.capacity = 1; + assert!( + !db.has_capacity(), + "after reducing the limit capacity should no longer be available" + ); + + advance(TIMESTEP + EPSILON).await; + assert!( + db.has_capacity(), + "after waiting for 2nd entry to expire, capacity should be available again" + ); + assert!(db.no_capacity_for().is_zero(), "capacity should be available right now"); + + assert!(db.get(&url).is_none(), "initial entry should have expired"); + assert!(db.get(&url_2).is_none(), "second entry should have expired"); + } + + #[tokio::test] + async fn test_mock_opt_in() { + let mut server = Server::new_async().await; + let url = + GatewayUri::from_str(&server.url()).expect("must be able to parse mock server URL"); + + let prober = Prober::default(); + + let mock_opt_in = server + .mock("GET", RFC_9540_GATEWAY_PATH) + .match_query(mockito::Matcher::Regex("^allowed_purposes$".into())) + .with_header(hyper::header::CONTENT_TYPE.as_str(), ALLOWED_PURPOSES_CONTENT_TYPE) + .with_body(BIP77_OPT_IN_RESPONSE) + .create(); + + // test happy path + let status = prober.check_opt_in(&url).await.expect("probing must succeed"); + assert!(status.bip77_allowed, "mock gateway opt-in should have been detected"); + mock_opt_in.assert(); + drop(mock_opt_in); + + // test cached result, mockit server will cause failure if another GET query is sent + let status = prober.check_opt_in(&url).await.expect("second probe must succeed"); + assert!(status.bip77_allowed, "gateway opt-in should be cached"); + } + + #[tokio::test] + async fn test_assert_opt_in() { + // no mock handlers, so any request should fail + let server = Server::new_async().await; + let url = + GatewayUri::from_str(&server.url()).expect("must be able to parse mock server URL"); + + let prober = Prober::default(); + + prober.assert_opt_in(&url).await.expect("asserting opt in should succeed"); + assert!( + prober.assert_opt_in(&url).await.is_none(), + "asserting opt in a second time should fail" + ); + + // test happy path + let status = 
prober.check_opt_in(&url).await.expect("probing must succeed"); + assert!(status.bip77_allowed, "asserte opt-in should be cached"); + } + + #[tokio::test] + async fn test_mock_no_opt_in() { + let mut server = Server::new_async().await; + let url = + GatewayUri::from_str(&server.url()).expect("must be able to parse mock server URL"); + + let prober = Prober::default(); + + let mock_only_rfc_9540 = server + .mock("GET", RFC_9540_GATEWAY_PATH) + .match_query(mockito::Matcher::Regex("^allowed_purposes$".into())) + .with_header(hyper::header::CONTENT_TYPE.as_str(), "application/ohttp-keys") + .with_body(b"\x00") // note: not actually a valid ohttp-keys encoding + .create(); + + let status = prober.check_opt_in(&url).await.expect("probing must succeed"); + mock_only_rfc_9540.assert(); + assert!( + !status.bip77_allowed, + "RFC 9540 gateway which doesn't signal should not be considered opted-in" + ); + } + + #[tokio::test] + async fn test_mock_404() { + let mut server = Server::new_async().await; + let url = + GatewayUri::from_str(&server.url()).expect("must be able to parse mock server URL"); + + let prober = Prober::default(); + + let mock_not_found = server + .mock("GET", RFC_9540_GATEWAY_PATH) + .match_query(mockito::Matcher::Regex("^allowed_purposes$".into())) + .with_status(404) + .with_body("not found") + .create(); + + let status = prober.check_opt_in(&url).await.expect("probing must succeed"); + mock_not_found.assert(); + assert!(!status.bip77_allowed, "non-existent gateway should not be considered opt-in"); + } + + #[tokio::test] + async fn test_inflight_deduplication() { + let mut server = Server::new_async().await; + let url = + GatewayUri::from_str(&server.url()).expect("must be able to parse mock server URL"); + + let prober = Prober::default(); + + let counter = Arc::new(Mutex::new(0)); + let condvar = Arc::new(std::sync::Condvar::new()); + let cvmutex = Arc::new(Mutex::new(false)); + + let mock_delayed = { + let counter = counter.clone(); + let condvar = 
condvar.clone(); + let cvmutex = cvmutex.clone(); + + server + .mock("GET", RFC_9540_GATEWAY_PATH) + .match_query(mockito::Matcher::Regex("^allowed_purposes$".into())) + .with_header(hyper::header::CONTENT_TYPE.as_str(), ALLOWED_PURPOSES_CONTENT_TYPE) + .with_chunked_body(move |w| { + // track how many requests have been received + let mut c = counter.lock().unwrap(); + *c += 1; + + // wait until both probe tasks were started + let mut guard = cvmutex.lock().unwrap(); + while !*guard { + guard = condvar.wait(guard).unwrap(); + } + + w.write_all(BIP77_OPT_IN_RESPONSE) + }) + .create() + }; + + let check_a = prober.check_opt_in(&url); + let check_b = prober.check_opt_in(&url); + + let ensure_both_inflight = async { + // wait until both probe tasks are in flight + loop { + std::thread::yield_now(); + let mut guard = prober.gateways.write().await; + if let Some(Status::InFlight(fut)) = guard.get(&url) { + if fut.strong_count().expect("inflight future should not yet be resolved") == 2 + { + break; + } + } + + // avoid spinlock, let probe tasks make progress + tokio::time::sleep(Duration::from_micros(10)).await; + } + + // release the server + { + let mut guard = cvmutex.lock().unwrap(); + *guard = true; // Set the condition to true + } + condvar.notify_one(); + }; + + let (a, b, _) = tokio::join!(check_a, check_b, ensure_both_inflight); + + mock_delayed.assert(); + assert!( + a.expect("probe must succeed").bip77_allowed, + "first concurrent request should detect opt-in" + ); + assert!( + b.expect("probe must succeed").bip77_allowed, + "second concurrent request should detect opt-in" + ); + assert_eq!(*counter.lock().unwrap(), 1, "requests should have been deduplicated"); + } + + #[test] + fn test_parse_alpn_encoded() { + let result = parse_alpn_encoded(b""); + assert!(result.is_err(), "empty string should not be valid"); + + let result = parse_alpn_encoded(b"\x00"); + assert!(result.is_err(), "null byte should not be valid"); + + let result = 
parse_alpn_encoded(b"\x00\x00"); + assert_eq!( + result.expect("a list of length 0 should parse without error").len(), + 0, + "empty list should have len 0" + ); + + let result = parse_alpn_encoded(b"\x00\x00\x00"); + assert!(result.is_err(), "trailing data should be invalid"); + + let result = parse_alpn_encoded(b"\x00\x01"); + assert!(result.is_err(), "a truncated list of length 1 should be invalid"); + + let result = parse_alpn_encoded(b"\x00\x01\x00") + .expect("a list with one empty element should parse without error"); + assert_eq!(result.len(), 1, "should contain 1 element"); + assert_eq!(result[0].len(), 0, "the single element should be of length 0"); + + let result = parse_alpn_encoded(b"\x00\x01\x01a") + .expect("a list with one element of length 1 should parse without error"); + assert_eq!(result.len(), 1, "should contain 1 element"); + assert_eq!(result[0].len(), 1, "element length should be 1"); + assert_eq!(result[0][0], b'a', "the element value should be the single byte 'a'"); + + let result = parse_alpn_encoded(b"\x00\x02\x01\x00\x00") + .expect("list with two elements should parse correctly"); + assert_eq!(result.len(), 2, "two element list should be valid"); + assert_eq!(result[0].len(), 1, "the first element should be a 1 byte long"); + assert_eq!(result[0][0], 0, "the first element should be a null byte"); + assert_eq!(result[1].len(), 0, "the second element should be empty"); + + let result = parse_alpn_encoded(BIP77_OPT_IN_RESPONSE) + .expect("stock BIP 77 opt in response should parse correctly"); + assert_eq!(result.len(), 1, "pre canned BIP 77 opt-in response should have 1 element"); + assert_eq!( + result[0], MAGIC_BIP77_PURPOSE, + "the element should be the bip77 opt-in magic string" + ); + } +} diff --git a/ohttp-relay/src/gateway_uri.rs b/ohttp-relay/src/gateway_uri.rs new file mode 100644 index 000000000..6bb6894ec --- /dev/null +++ b/ohttp-relay/src/gateway_uri.rs @@ -0,0 +1,176 @@ +use std::str::FromStr; + +use http::uri::{Authority, 
Scheme};
use http::Uri;

use crate::error::BoxError;

pub(crate) const RFC_9540_GATEWAY_PATH: &str = "/.well-known/ohttp-gateway";
const ALLOWED_PURPOSES_PATH_AND_QUERY: &str = "/.well-known/ohttp-gateway?allowed_purposes";

/// A normalized gateway origin URI with a default port if none is specified.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GatewayUri {
    scheme: Scheme,
    authority: Authority,
}

impl GatewayUri {
    /// Build a gateway origin from a scheme and authority.
    ///
    /// Only `http` and `https` are supported. When the authority carries no
    /// explicit port, the scheme's default (80/443) is made explicit so that
    /// equality, hashing and DNS lookups behave consistently.
    pub fn new(scheme: Scheme, authority: Authority) -> Result<Self, BoxError> {
        let default_port = if scheme == Scheme::HTTP {
            80
        } else if scheme == Scheme::HTTPS {
            443
        } else {
            return Err("Unsupported URI scheme".into());
        };

        // If no explicit port is provided, make the default one explicit
        let mut authority = authority;
        if authority.port().is_none() {
            authority = Authority::from_str(&format!("{}:{}", authority.host(), default_port))
                .expect("setting default port must succeed");
        }

        Ok(Self { scheme, authority })
    }

    /// Parse a compile-time constant gateway URI.
    ///
    /// Panics on anything other than a scheme + authority URI; intended for
    /// static configuration and tests only.
    pub fn from_static(string: &'static str) -> Self {
        Uri::from_static(string)
            .try_into()
            .expect("gateway URI must consist of a scheme and authority only")
    }

    fn to_uri_builder(&self) -> http::uri::Builder {
        Uri::builder().scheme(self.scheme.clone()).authority(self.authority.clone())
    }

    /// The origin as a URI with a bare "/" path.
    pub fn to_uri(&self) -> Uri {
        self.to_uri_builder()
            .path_and_query("/")
            .build()
            .expect("Building Uri from scheme and authority must succeed")
    }

    /// The RFC 9540 well-known gateway endpoint for this origin.
    pub fn rfc_9540_url(&self) -> Uri {
        self.to_uri_builder()
            .path_and_query(RFC_9540_GATEWAY_PATH)
            .build()
            .expect("building RFC 9540 uri from scheme and authority must succeed")
    }

    /// The allowed-purposes probe endpoint for this origin.
    pub fn probe_url(&self) -> Uri {
        self.to_uri_builder()
            .path_and_query(ALLOWED_PURPOSES_PATH_AND_QUERY)
            .build()
            // message fixed: this builds the allowed_purposes probe URL, not
            // the RFC 9540 gateway URL
            .expect("building allowed_purposes probe uri must succeed")
    }

    /// Resolve the authority to its first socket address, if any.
    pub async fn to_socket_addr(&self) -> std::io::Result<Option<std::net::SocketAddr>> {
        Ok(self.to_socket_addrs().await?.next())
    }

    /// Resolve the authority via DNS. The constructor guarantees an explicit
    /// port, so `lookup_host` always has one to work with.
    pub async fn to_socket_addrs(
        &self,
    ) -> std::io::Result<impl Iterator<Item = std::net::SocketAddr>> {
        tokio::net::lookup_host(self.authority.to_string()).await
    }
}

impl From<GatewayUri> for Uri {
    fn from(val: GatewayUri) -> Uri { val.to_uri() }
}

impl TryFrom<Uri> for GatewayUri {
    type Error = BoxError;

    fn try_from(uri: Uri) -> Result<Self, Self::Error> {
        let parts = uri.into_parts();

        if let Some(pq) = parts.path_and_query {
            if pq.as_str() != "/" {
                return Err("URI must not contain path or query".into());
            }
        }

        // `ok_or` with a static str defers the BoxError allocation to the `?`
        // conversion instead of allocating eagerly on the success path.
        let scheme = parts.scheme.ok_or("URI must have a scheme")?;
        let authority = parts.authority.ok_or("URI must have an authority")?;

        Self::new(scheme, authority)
    }
}

impl From<Authority> for GatewayUri {
    fn from(authority: Authority) -> Self {
        // Bare authorities (e.g. from CONNECT requests) are assumed HTTPS.
        Self::new(Scheme::HTTPS, authority)
            .expect("constructing GatewayUri with valid authority must succeed")
    }
}

impl FromStr for GatewayUri {
    type Err = BoxError;
    fn from_str(string: &str) -> Result<Self, Self::Err> { Uri::from_str(string)?.try_into() }
}

#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn conversion() {
        let uri_with_port = Uri::from_static("http://payjo.in:80");
        let gateway_uri = GatewayUri::try_from(uri_with_port.clone())
            .expect("should be a valid gateway base URI");
        assert_eq!(gateway_uri.to_uri(), uri_with_port, "uri should be the same as input");

        let uri_without_port = Uri::from_static("http://payjo.in");
        let gateway_uri =
            GatewayUri::try_from(uri_without_port).expect("should be a valid gateway base URI");

        let uri: Uri = gateway_uri.clone().into();
        assert_eq!(uri, uri_with_port, "uri should be canonicalized to contain port");

        assert_eq!(
            gateway_uri.rfc_9540_url(),
            Uri::from_static("http://payjo.in:80/.well-known/ohttp-gateway"),
            "uri should be canonicalized to contain port"
        );
    }

    #[test]
    fn default_port() {
        let uri = GatewayUri::from_static("http://payjo.in");
        assert_eq!(
            uri.authority.port_u16(),
            Some(80),
            "default port should be made explicit for http scheme"
        );

        let uri = GatewayUri::from_static("https://payjo.in");
        assert_eq!(
            uri.authority.port_u16(),
            Some(443),
            "default port should be made explicit for https scheme"
        );

        let uri = GatewayUri::from_static("https://payjo.in:80");
        assert_eq!(uri.authority.port_u16(), Some(80), "explicit port should override default");

        let uri = GatewayUri::from_static("http://payjo.in:1234");
        assert_eq!(uri.authority.port_u16(), Some(1234), "explicit port should override default");
    }

    #[test]
    fn invalid_uris() {
        assert!(GatewayUri::from_str("payjo.in").is_err(), "scheme is mandatory");

        assert!(GatewayUri::from_str("/index.html").is_err(), "url must be absolute");

        assert!(
            GatewayUri::from_str("ftp://payjo.in").is_err(),
            "only http and https scheme should be allowed"
        );

        assert!(GatewayUri::from_str("http://payjo.in/blah").is_err(), "url must not contain path");
    }
}
diff --git a/ohttp-relay/src/lib.rs b/ohttp-relay/src/lib.rs
new file mode 100644
index 000000000..1af4828f7
--- /dev/null
+++ b/ohttp-relay/src/lib.rs
@@ -0,0 +1,305 @@
use std::net::SocketAddr;
use std::str::FromStr;
use std::sync::Arc;

pub(crate) use gateway_prober::Prober;
pub use gateway_uri::GatewayUri;
use http::uri::Authority;
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Empty, Full};
use hyper::body::{Bytes, Incoming};
use hyper::header::{
    HeaderValue, ACCESS_CONTROL_ALLOW_HEADERS, ACCESS_CONTROL_ALLOW_METHODS,
    ACCESS_CONTROL_ALLOW_ORIGIN, CONTENT_LENGTH, CONTENT_TYPE,
};
use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Method, Request, Response};
use hyper_rustls::builderstates::WantsSchemes;
use hyper_rustls::{HttpsConnector, HttpsConnectorBuilder};
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::client::legacy::Client;
use hyper_util::rt::{TokioExecutor, TokioIo};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::{TcpListener, UnixListener};
use 
tokio_util::net::Listener;
use tracing::{error, info, instrument};

pub mod error;
#[cfg(not(feature = "_test-util"))]
mod gateway_prober;
#[cfg(feature = "_test-util")]
pub mod gateway_prober;
mod gateway_uri;
use crate::error::{BoxError, Error};

#[cfg(any(feature = "connect-bootstrap", feature = "ws-bootstrap"))]
pub mod bootstrap;

pub const DEFAULT_PORT: u16 = 3000;
pub const OHTTP_RELAY_HOST: HeaderValue = HeaderValue::from_static("0.0.0.0");
pub const EXPECTED_MEDIA_TYPE: HeaderValue = HeaderValue::from_static("message/ohttp-req");

/// Bind a TCP listener on all interfaces and serve the OHTTP relay.
///
/// Returns the join handle of the accept-loop task.
#[instrument]
pub async fn listen_tcp(
    port: u16,
    gateway_origin: GatewayUri,
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
    let addr = SocketAddr::from(([0, 0, 0, 0], port));
    let listener = TcpListener::bind(addr).await?;
    // structured logging, consistent with listen_socket (was println!)
    info!("OHTTP relay listening on tcp://{}", addr);
    ohttp_relay(listener, RelayConfig::new_with_default_client(gateway_origin)).await
}

/// Bind a Unix domain socket at `socket_path` and serve the OHTTP relay.
#[instrument]
pub async fn listen_socket(
    socket_path: &str,
    gateway_origin: GatewayUri,
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError> {
    let listener = UnixListener::bind(socket_path)?;
    info!("OHTTP relay listening on socket: {}", socket_path);
    ohttp_relay(listener, RelayConfig::new_with_default_client(gateway_origin)).await
}

/// Test helper: bind an ephemeral port and serve with a caller-supplied root
/// store so tests can trust self-signed certificates.
#[cfg(feature = "_test-util")]
pub async fn listen_tcp_on_free_port(
    default_gateway: GatewayUri,
    root_store: rustls::RootCertStore,
) -> Result<(u16, tokio::task::JoinHandle<Result<(), BoxError>>), BoxError> {
    let listener = tokio::net::TcpListener::bind("[::]:0").await?;
    let port = listener.local_addr()?.port();
    // structured logging, consistent with the other listeners (was println!)
    info!("OHTTP relay binding to port {}", listener.local_addr()?);
    let config = RelayConfig::new(default_gateway, root_store);
    let handle = ohttp_relay(listener, config).await?;
    Ok((port, handle))
}

/// Per-relay configuration: the default gateway plus the shared HTTP client
/// and the opt-in prober built on top of that client.
#[derive(Debug)]
struct RelayConfig {
    default_gateway: GatewayUri,
    client: HttpClient,
    prober: Prober,
}

impl RelayConfig {
    fn new_with_default_client(default_gateway: GatewayUri) -> Self {
        Self::new(default_gateway, HttpClient::default())
    }

    fn new(default_gateway: GatewayUri, into_client: impl Into<HttpClient>) -> Self {
        let client = into_client.into();
        let prober = Prober::new_with_client(client.clone());
        RelayConfig { default_gateway, client, prober }
    }
}

/// Shared hyper client used for both request forwarding and probing.
#[derive(Debug, Clone)]
pub(crate) struct HttpClient(
    hyper_util::client::legacy::Client<HttpsConnector<HttpConnector>, BoxBody<Bytes, hyper::Error>>,
);

impl std::ops::Deref for HttpClient {
    type Target = hyper_util::client::legacy::Client<
        HttpsConnector<HttpConnector>,
        BoxBody<Bytes, hyper::Error>,
    >;
    fn deref(&self) -> &Self::Target { &self.0 }
}

impl From<HttpsConnectorBuilder<WantsSchemes>> for HttpClient {
    fn from(builder: HttpsConnectorBuilder<WantsSchemes>) -> Self {
        let https = builder.https_or_http().enable_http1().build();
        Self(Client::builder(TokioExecutor::new()).build(https))
    }
}

impl Default for HttpClient {
    fn default() -> Self { HttpsConnectorBuilder::new().with_webpki_roots().into() }
}

impl From<rustls::RootCertStore> for HttpClient {
    fn from(root_store: rustls::RootCertStore) -> Self {
        HttpsConnectorBuilder::new()
            .with_tls_config(
                rustls::ClientConfig::builder()
                    .with_root_certificates(root_store)
                    .with_no_client_auth(),
            )
            .into()
    }
}

/// Accept connections from `listener` and serve each over HTTP/1, after
/// asserting that the default gateway has opted in to relaying.
#[instrument(skip(listener))]
async fn ohttp_relay<L>(
    mut listener: L,
    config: RelayConfig,
) -> Result<tokio::task::JoinHandle<Result<(), BoxError>>, BoxError>
where
    L: Listener + Unpin + Send + 'static,
    L::Io: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
    config.prober.assert_opt_in(&config.default_gateway).await;

    let config = Arc::new(config);

    let handle = tokio::spawn(async move {
        while let Ok((stream, _)) = listener.accept().await {
            let config = config.clone();
            let io = TokioIo::new(stream);
            tokio::spawn(async move {
                if let Err(err) = http1::Builder::new()
                    .serve_connection(io, service_fn(|req| serve_ohttp_relay(req, &config)))
                    .with_upgrades()
                    .await
                {
                    error!("Error serving connection: {:?}", err);
                }
            });
        }
        Ok(())
    });

    Ok(handle)
}

/// Route a single request: CORS preflight, health check, OHTTP forwarding,
/// or (with bootstrap features enabled) ohttp-keys bootstrapping.
#[instrument]
async fn serve_ohttp_relay(
    req: Request<Incoming>,
    config: &RelayConfig,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error>
{ + let mut res = match (req.method(), req.uri().path()) { + (&Method::OPTIONS, _) => Ok(handle_preflight()), + (&Method::GET, "/health") => Ok(health_check().await), + (&Method::POST, _) => match parse_gateway_uri(&req, config).await { + Ok(gateway_uri) => handle_ohttp_relay(req, config, gateway_uri).await, + Err(e) => Err(e), + }, + #[cfg(any(feature = "connect-bootstrap", feature = "ws-bootstrap"))] + (&Method::GET, _) | (&Method::CONNECT, _) => match parse_gateway_uri(&req, config).await { + Ok(gateway_uri) => crate::bootstrap::handle_ohttp_keys(req, gateway_uri).await, + Err(e) => Err(e), + }, + _ => Err(Error::NotFound), + } + .unwrap_or_else(|e| e.to_response()); + res.headers_mut().insert(ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*")); + Ok(res) +} + +async fn parse_gateway_uri( + req: &Request, + config: &RelayConfig, +) -> Result { + // for POST and GET (websockets), the gateway URI is provided in the path + // for CONNECT requests, just an authority is provided, and we assume HTTPS + let gateway_uri = match req.method() { + &Method::CONNECT => req.uri().authority().cloned().map(GatewayUri::from), + _ => parse_gateway_uri_from_path(req.uri().path(), &config.default_gateway).ok(), + } + .ok_or_else(|| Error::BadRequest("Invalid gateway".to_string()))?; + + let policy = match config.prober.check_opt_in(&gateway_uri).await { + Some(policy) => Ok(policy), + None => Err(Error::Unavailable(config.prober.unavailable_for().await)), + }?; + + if policy.bip77_allowed { + Ok(gateway_uri) + } else { + // TODO Cache-Control header for error based on policy.expires + // is not found the right error? maybe forbidden or bad gateway? + // prober policy judgement can be an enum instead of a bool to + // distinguish 4xx vs. 
5xx failures, 4xx being an explicit opt out and + // 5xx for IO errors etc + Err(Error::NotFound) + } +} + +fn parse_gateway_uri_from_path(path: &str, default: &GatewayUri) -> Result { + if path.is_empty() || path == "/" { + return Ok(default.clone()); + } + + let path = &path[1..]; + + if "http://" == &path[..7] || "https://" == &path[..8] { + GatewayUri::from_str(path) + } else { + Ok(Authority::from_str(path)?.into()) + } +} + +fn handle_preflight() -> Response> { + let mut res = Response::new(empty()); + *res.status_mut() = hyper::StatusCode::NO_CONTENT; + res.headers_mut().insert(ACCESS_CONTROL_ALLOW_ORIGIN, HeaderValue::from_static("*")); + res.headers_mut().insert( + ACCESS_CONTROL_ALLOW_METHODS, + HeaderValue::from_static("CONNECT, GET, OPTIONS, POST"), + ); + res.headers_mut().insert( + ACCESS_CONTROL_ALLOW_HEADERS, + HeaderValue::from_static("Content-Type, Content-Length"), + ); + res +} + +async fn health_check() -> Response> { Response::new(empty()) } + +#[instrument] +async fn handle_ohttp_relay( + req: Request, + config: &RelayConfig, + gateway: GatewayUri, +) -> Result>, Error> { + let fwd_req = into_forward_req(req, gateway)?; + forward_request(fwd_req, config).await.map(|res| { + let (parts, body) = res.into_parts(); + let boxed_body = BoxBody::new(body); + Response::from_parts(parts, boxed_body) + }) +} + +/// Convert an incoming request into a request to forward to the target gateway server. 
+#[instrument] +fn into_forward_req( + req: Request, + gateway_origin: GatewayUri, +) -> Result>, Error> { + let (head, body) = req.into_parts(); + + if head.method != hyper::Method::POST { + return Err(Error::MethodNotAllowed); + } + + if head.headers.get(CONTENT_TYPE) != Some(&EXPECTED_MEDIA_TYPE) { + return Err(Error::UnsupportedMediaType); + } + + let mut builder = Request::builder() + .method(hyper::Method::POST) + .uri(gateway_origin.rfc_9540_url()) + .header(CONTENT_TYPE, EXPECTED_MEDIA_TYPE); + + if let Some(content_length) = head.headers.get(CONTENT_LENGTH) { + builder = builder.header(CONTENT_LENGTH, content_length); + } + + builder.body(BoxBody::new(body)).map_err(|e| Error::InternalServerError(Box::new(e))) +} + +#[instrument] +async fn forward_request( + req: Request>, + config: &RelayConfig, +) -> Result, Error> { + config.client.request(req).await.map_err(|_| Error::BadGateway) +} + +pub(crate) fn empty() -> BoxBody { + Empty::::new().map_err(|never| match never {}).boxed() +} + +pub(crate) fn full>(chunk: T) -> BoxBody { + Full::new(chunk.into()).map_err(|never| match never {}).boxed() +} diff --git a/ohttp-relay/src/main.rs b/ohttp-relay/src/main.rs new file mode 100644 index 000000000..ea7811e7e --- /dev/null +++ b/ohttp-relay/src/main.rs @@ -0,0 +1,41 @@ +use std::str::FromStr; + +use ohttp_relay::{GatewayUri, DEFAULT_PORT}; +use tracing_subscriber::layer::SubscriberExt; +use tracing_subscriber::util::SubscriberInitExt; +use tracing_subscriber::{fmt, EnvFilter}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + rustls::crypto::ring::default_provider() + .install_default() + .expect("Failed to install default crypto provider"); + + init_tracing(); + let port_env = std::env::var("PORT"); + let unix_socket_env = std::env::var("UNIX_SOCKET"); + let gateway_origin_str = std::env::var("GATEWAY_ORIGIN").expect("GATEWAY_ORIGIN is required"); + let gateway_origin = + GatewayUri::from_str(&gateway_origin_str).expect("Invalid GATEWAY_ORIGIN URI"); + 
+ match (port_env, unix_socket_env) { + (Ok(_), Ok(_)) => panic!( + "Both PORT and UNIX_SOCKET environment variables are set. Please specify only one." + ), + (Err(_), Ok(unix_socket_path)) => + ohttp_relay::listen_socket(&unix_socket_path, gateway_origin).await?, + (Ok(port_str), Err(_)) => { + let port: u16 = port_str.parse().expect("Invalid PORT"); + ohttp_relay::listen_tcp(port, gateway_origin).await? + } + (Err(_), Err(_)) => ohttp_relay::listen_tcp(DEFAULT_PORT, gateway_origin).await?, + } + .await? +} + +fn init_tracing() { + tracing_subscriber::registry() + .with(EnvFilter::from_default_env()) + .with(fmt::layer().with_target(true)) // Log the target (usually the module path and function name) + .init(); +} diff --git a/ohttp-relay/tests/integration.rs b/ohttp-relay/tests/integration.rs new file mode 100644 index 000000000..b00c779c6 --- /dev/null +++ b/ohttp-relay/tests/integration.rs @@ -0,0 +1,528 @@ +#[cfg(test)] +#[cfg(feature = "_test-util")] +mod integration { + use std::fs::File; + use std::io::Read; + use std::net::SocketAddr; + use std::path::PathBuf; + use std::str::FromStr; + + use hex::FromHex; + use http_body_util::combinators::BoxBody; + use http_body_util::{BodyExt, Full}; + use hyper::body::{Bytes, Incoming}; + use hyper::header::{HeaderValue, CONTENT_LENGTH, CONTENT_TYPE}; + use hyper::server::conn::http1; + use hyper::service::service_fn; + use hyper::{Request, Response}; + use hyper_rustls::HttpsConnectorBuilder; + use hyper_util::client::legacy::Client; + use hyper_util::rt::{TokioExecutor, TokioIo}; + use ohttp_relay::gateway_prober::{ALLOWED_PURPOSES_CONTENT_TYPE, MAGIC_BIP77_PURPOSE}; + use ohttp_relay::*; + use rcgen::Certificate; + use rustls::pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs8KeyDer}; + use tempfile::NamedTempFile; + use tokio::net::{TcpListener, TcpStream}; + use tokio::process::Command; + + static INIT: std::sync::Once = std::sync::Once::new(); + + fn init_crypto_provider() { + INIT.call_once(|| { + 
rustls::crypto::ring::default_provider() + .install_default() + .expect("Failed to install default crypto provider"); + }); + } + + const ENCAPSULATED_REQ: &str = "010020000100014b28f881333e7c164ffc499ad9796f877f4e1051ee6d31bad19dec96c208b4726374e469135906992e1268c594d2a10c695d858c40a026e7965e7d86b83dd440b2c0185204b4d63525"; + const ENCAPSULATED_RES: &str = + "c789e7151fcba46158ca84b04464910d86f9013e404feea014e7be4a441f234f857fbd"; + + /// See: https://www.ietf.org/rfc/rfc9458.html#name-complete-example-of-a-reque + #[tokio::test] + async fn test_request_response_tcp() { + init_crypto_provider(); + let gateway_port = find_free_port(); + let gateway = GatewayUri::from_str(&format!("http://0.0.0.0:{}", gateway_port)).unwrap(); + + let nginx_cert = gen_localhost_cert(); + let nginx_cert_der = cert_to_cert_der(&nginx_cert); + let mut root_store = rustls::RootCertStore::empty(); + root_store.add(nginx_cert_der.clone()).unwrap(); + + let (relay_port, relay_handle) = listen_tcp_on_free_port(gateway.clone(), root_store) + .await + .expect("Failed to listen on free port"); + let relay_task = tokio::spawn(async move { + if let Err(e) = relay_handle.await { + eprintln!("Relay failed: {}", e); + } + }); + + let n_http_port = find_free_port(); + let n_https_port = find_free_port(); + let _nginx = + start_nginx(n_http_port, n_https_port, format!("0.0.0.0:{}", relay_port), nginx_cert) + .await; + tokio::select! 
{ + _ = example_gateway_http(gateway_port) => { + panic!("Gateway is long running"); + } + _ = relay_task => { + panic!("Relay is long running"); + } + _ = ohttp_req(n_https_port, nginx_cert_der, gateway) => {} + } + } + + #[tokio::test] + async fn test_request_response_socket() -> Result<(), Box> { + init_crypto_provider(); + let temp_dir = std::env::temp_dir(); + let socket_path = temp_dir.as_path().join("test.socket"); + + if socket_path.exists() { + std::fs::remove_file(&socket_path).expect("Failed to remove existing socket file"); + } + + let gateway_port = find_free_port(); + let gateway = GatewayUri::from_str(&format!("http://0.0.0.0:{}", gateway_port)).unwrap(); + let nginx_cert = gen_localhost_cert(); + let nginx_cert_der = cert_to_cert_der(&nginx_cert); + let socket_path_str = socket_path.to_str().unwrap(); + let relay_handle = listen_socket(socket_path_str, gateway.clone()) + .await + .expect("Failed to listen on socket"); + let relay_task = tokio::spawn(async move { + if let Err(e) = relay_handle.await { + eprintln!("Relay failed: {}", e); + } + }); + let n_http_port = find_free_port(); + let n_https_port = find_free_port(); + let _nginx = + start_nginx(n_http_port, n_https_port, format!("unix:{}", socket_path_str), nginx_cert) + .await?; + tokio::select! 
{ + _ = example_gateway_http(gateway_port) => { + panic!("Gateway is long running"); + } + _ = relay_task => { + panic!("Relay is long running"); + } + _ = ohttp_req(n_https_port, nginx_cert_der, gateway) => {} + } + Ok(()) + } + + async fn example_gateway_http(port: u16) -> Result<(), Box> { + example_gateway(port, |stream| { + tokio::spawn(async move { + let io = TokioIo::new(stream); + if let Err(err) = + http1::Builder::new().serve_connection(io, service_fn(handle_gateway)).await + { + println!("Failed to serve connection: {:?}", err); + } + }); + }) + .await + } + + async fn handle_gateway( + req: Request, + ) -> Result>, hyper::Error> { + let res = match (req.method(), req.uri().path(), req.uri().query()) { + (&hyper::Method::POST, "/.well-known/ohttp-gateway", _) => handle_ohttp_req(req).await, + #[cfg(any(feature = "connect-bootstrap", feature = "ws-bootstrap"))] + (&hyper::Method::GET, "/.well-known/ohttp-gateway", None) => + bootstrap::handle_ohttp_keys(req).await, + (&hyper::Method::GET, "/.well-known/ohttp-gateway", Some("allowed_purposes")) => + handle_opt_in(req).await, + _ => panic!("Unexpected request: {} {}", req.method(), req.uri().path()), + } + .unwrap(); + Ok(res) + } + + async fn handle_ohttp_req( + _: Request, + ) -> Result>, hyper::Error> { + let mut res = Response::new(full(Vec::from_hex(ENCAPSULATED_RES).unwrap()).boxed()); + *res.status_mut() = hyper::StatusCode::OK; + res.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("message/ohttp-res")); + res.headers_mut().insert(CONTENT_LENGTH, HeaderValue::from_static("35")); + Ok(res) + } + + async fn handle_opt_in( + _: Request, + ) -> Result>, hyper::Error> { + let mut res = Response::new(full([b"\x00\x01\x2a", MAGIC_BIP77_PURPOSE].concat())); + *res.status_mut() = hyper::StatusCode::OK; + res.headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static(ALLOWED_PURPOSES_CONTENT_TYPE)); + res.headers_mut().insert(CONTENT_LENGTH, HeaderValue::from_static("45")); + Ok(res) + } + + 
async fn ohttp_req(relay_port: u16, cert: CertificateDer<'static>, gateway: GatewayUri) -> () { + for gw_path in ["", &gateway.to_uri().to_string()] { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + let mut req = Request::new(full(Vec::from_hex(ENCAPSULATED_REQ).unwrap()).boxed()); + *req.method_mut() = hyper::Method::POST; + *req.uri_mut() = format!("https://0.0.0.0:{}/{}", relay_port, gw_path).parse().unwrap(); + req.headers_mut().insert(CONTENT_TYPE, HeaderValue::from_static("message/ohttp-req")); + req.headers_mut().insert(CONTENT_LENGTH, HeaderValue::from_static("78")); + + let mut root_store = rustls::RootCertStore::empty(); + root_store.add(cert.clone()).unwrap(); + + let config = rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + + let https = HttpsConnectorBuilder::new() + .with_tls_config(config) + .https_or_http() + .enable_http1() + .build(); + let client = Client::builder(TokioExecutor::new()).build(https); + let res = client.request(req).await.unwrap(); + assert_eq!(res.status(), hyper::StatusCode::OK); + assert_eq!( + res.headers().get(CONTENT_TYPE), + Some(&HeaderValue::from_static("message/ohttp-res")) + ); + assert_eq!(res.headers().get(CONTENT_LENGTH), Some(&HeaderValue::from_static("35"))); + } + } + + async fn example_gateway(port: u16, handle_conn: F) -> Result<(), Box> + where + F: Fn(TcpStream) + Clone + Send + Sync + 'static, + { + let addr = SocketAddr::from(([127, 0, 0, 1], port)); + let listener = TcpListener::bind(addr).await?; + println!("Gateway listening on port {}", port); + + loop { + let (stream, _) = listener.accept().await?; + let handle_conn = handle_conn.clone(); + + tokio::task::spawn(async move { + handle_conn(stream); + }); + } + } + + fn find_free_port() -> u16 { + let listener = std::net::TcpListener::bind("0.0.0.0:0").unwrap(); + listener.local_addr().unwrap().port() + } + + #[cfg(any(feature = "connect-bootstrap", feature = "ws-bootstrap"))] + mod bootstrap 
{ + use std::future::Future; + use std::io::Write; + use std::pin::Pin; + use std::sync::Arc; + + use rustls::pki_types::{self, CertificateDer}; + use rustls::ServerConfig; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + use tokio_rustls::{TlsAcceptor, TlsConnector}; + + use super::*; + + const OHTTP_KEYS: &str = "01002031e1f05a740102115220e9af918f738674aec95f54db6e04eb705aae8e79815500080001000100010003"; + + #[cfg(feature = "ws-bootstrap")] + mod ws_bootstrap { + use tokio_tungstenite::connect_async; + + use super::*; + + #[tokio::test] + async fn test_ws_bootstrap() { + init_crypto_provider(); + test_bootstrap(|relay_port, gateway, cert| { + Box::pin(ohttp_keys_ws_client(relay_port, gateway.clone(), cert)) + }) + .await; + } + + async fn ohttp_keys_ws_client( + relay_port: u16, + gateway: GatewayUri, + cert: CertificateDer<'_>, + ) { + use ohttp_relay::bootstrap::ws::WsIo; + + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + + let mut root_store = rustls::RootCertStore::empty(); + root_store.add(cert).unwrap(); + let config = tokio_rustls::rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + + let (ws_stream, _res) = + connect_async(format!("ws://0.0.0.0:{}/{}", relay_port, gateway.to_uri())) + .await + .expect("Failed to connect"); + println!("Connected to ws"); + let ws_io = WsIo::new(ws_stream); + let connector = TlsConnector::from(Arc::new(config)); + let domain = pki_types::ServerName::try_from("0.0.0.0") + .map_err(|_| { + std::io::Error::new(std::io::ErrorKind::InvalidInput, "invalid dnsname") + }) + .unwrap() + .to_owned(); + let mut tls_stream = connector.connect(domain, ws_io).await.unwrap(); + + let content = + b"GET /.well-known/ohttp-gateway HTTP/1.1\r\nHost: 0.0.0.0\r\nConnection: close\r\n\r\n"; + tls_stream.write_all(content).await.unwrap(); + tls_stream.flush().await.unwrap(); + let mut plaintext = Vec::new(); + let _ = tls_stream.read_to_end(&mut plaintext).await.unwrap(); + 
std::io::stdout().write_all(&plaintext).unwrap(); + } + } + + #[cfg(feature = "connect-bootstrap")] + mod connect_bootstrap { + use super::*; + + #[tokio::test] + async fn test_connect_bootstrap() { + init_crypto_provider(); + test_bootstrap(|relay_port, gateway, cert| { + Box::pin(ohttp_keys_connect_client(relay_port, gateway.clone(), cert)) + }) + .await; + } + + async fn ohttp_keys_connect_client( + relay_port: u16, + gateway: GatewayUri, + cert: CertificateDer<'_>, + ) { + tokio::time::sleep(std::time::Duration::from_secs(1)).await; + let client = reqwest::Client::builder() + .use_rustls_tls() + .tls_built_in_root_certs(false) + .add_root_certificate( + reqwest::Certificate::from_der(cert.as_ref()).expect("invalid cert der"), + ) + .proxy( + reqwest::Proxy::http(format!("http://0.0.0.0:{}", relay_port)) + .expect("invalid proxy"), + ) + .build() + .expect("failed building reqwest client"); + let url = gateway.rfc_9540_url(); + println!("gateway for proxy: {:?}", url); + let res = client.get(url.to_string()).send().await.unwrap(); + assert_eq!(res.status(), 200); + assert_eq!(res.headers().get("content-type").unwrap(), "application/ohttp-keys"); + assert_eq!(res.headers().get("content-length").unwrap(), "45"); + } + } + + async fn test_bootstrap(client_fn: F) + where + F: FnOnce( + u16, + &GatewayUri, + CertificateDer<'static>, + ) -> Pin>>, + { + let gateway_port = find_free_port(); + let gateway = + GatewayUri::from_str(&format!("https://0.0.0.0:{}", gateway_port)).unwrap(); + let nginx_cert = gen_localhost_cert(); + let nginx_cert_der = cert_to_cert_der(&nginx_cert); + let gateway_cert = gen_localhost_cert(); + let gateway_cert_der = cert_to_cert_der(&gateway_cert); + let mut root_store = rustls::RootCertStore::empty(); + root_store.add(gateway_cert_der.clone()).unwrap(); + root_store.add(nginx_cert_der).unwrap(); + let (relay_port, relay_handle) = listen_tcp_on_free_port(gateway.clone(), root_store) + .await + .expect("Failed to listen on free port"); + let 
relay_task = tokio::spawn(async move { + if let Err(e) = relay_handle.await { + eprintln!("Relay failed: {}", e); + } + }); + let n_http_port = find_free_port(); + let n_https_port = find_free_port(); + let _nginx = start_nginx( + n_http_port, + n_https_port, + format!("0.0.0.0:{}", relay_port), + nginx_cert, + ) + .await; + tokio::select! { + _ = example_gateway_https(gateway_port, gateway_cert) => { + panic!("Gateway is long running"); + } + _ = relay_task => { + panic!("Relay is long running"); + } + _ = client_fn(n_http_port, &gateway, gateway_cert_der) => {} + } + } + + async fn example_gateway_https( + port: u16, + cert: Certificate, + ) -> Result<(), Box> { + let acceptor = Arc::new(build_tls_acceptor(cert)); + + example_gateway(port, move |stream| { + let acceptor = acceptor.clone(); + tokio::spawn(async move { + let stream = acceptor.accept(stream).await.expect("TLS error"); + let io = TokioIo::new(stream); + if let Err(err) = + http1::Builder::new().serve_connection(io, service_fn(handle_gateway)).await + { + println!("Failed to serve connection: {:?}", err); + } + }); + }) + .await + } + + pub(crate) async fn handle_ohttp_keys( + _: Request, + ) -> Result>, hyper::Error> { + let mut res = Response::new(full(Vec::from_hex(OHTTP_KEYS).unwrap()).boxed()); + *res.status_mut() = hyper::StatusCode::OK; + res.headers_mut() + .insert(CONTENT_TYPE, HeaderValue::from_static("application/ohttp-keys")); + res.headers_mut().insert(CONTENT_LENGTH, HeaderValue::from_static("45")); + Ok(res) + } + + fn build_tls_acceptor(cert: Certificate) -> TlsAcceptor { + let (key, cert) = cert_to_key_cert_der(cert); + let server_config = ServerConfig::builder() + .with_no_client_auth() + .with_single_cert(vec![cert], key) + .unwrap(); + tokio_rustls::TlsAcceptor::from(Arc::new(server_config)) + } + } + + fn gen_localhost_cert() -> Certificate { + rcgen::generate_simple_self_signed(vec!["0.0.0.0".to_string()]).unwrap() + } + + fn cert_to_key_cert_der( + cert: Certificate, + ) -> 
(PrivateKeyDer<'static>, CertificateDer<'static>) { + let key = PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(cert.serialize_private_key_der())); + let cert = CertificateDer::from(cert.serialize_der().unwrap()); + (key, cert) + } + + fn cert_to_cert_der(cert: &Certificate) -> CertificateDer<'static> { + CertificateDer::from(cert.serialize_der().unwrap()) + } + + struct NginxProcess { + _child: tokio::process::Child, + config_path: PathBuf, + } + + impl Drop for NginxProcess { + fn drop(&mut self) { + // NGINX spawns child processes. Gracefully shut them all down. + let _ = std::process::Command::new("nginx") + .arg("-s") + .arg("stop") + .arg("-c") + .arg(self.config_path.as_os_str()) + .status(); + } + } + + async fn start_nginx( + n_http_port: u16, + n_https_port: u16, + proxy_pass: String, + cert: Certificate, + ) -> Result> { + use std::io::Write; + + let temp_dir = std::env::var("TMPDIR").unwrap_or_else(|_| "/tmp".into()); // Use Nix's TMPDIR + let unique_suffix = uuid::Uuid::new_v4().to_string(); // Ensures uniqueness + + let error_log_path = format!("{}/nginx_error_{}.log", temp_dir, unique_suffix); + let pid_path = format!("{}/nginx_{}.pid", temp_dir, unique_suffix); + + let cert_path = format!("{}/cert_{}.pem", temp_dir, unique_suffix); + std::fs::write(&cert_path, cert.serialize_pem().unwrap()) + .expect("Failed to write gateway cert"); + + let key_path = format!("{}/key_{}.pem", temp_dir, unique_suffix); + std::fs::write(&key_path, cert.serialize_private_key_pem()) + .expect("Failed to write gateway key"); + + let template_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("nginx.conf.template") + .canonicalize() + .unwrap(); + let mut template_file = File::open(template_path).expect("Failed to open template file"); + let mut template_content = String::new(); + template_file.read_to_string(&mut template_content).expect("Failed to read template file"); + + let nginx_conf = template_content + .replace("{{error_log_path}}", 
&error_log_path.to_string()) + .replace("{{pid_path}}", &pid_path.to_string()) + .replace("{{http_port}}", &n_http_port.to_string()) + .replace("{{https_port}}", &n_https_port.to_string()) + .replace("{{proxy_pass}}", &proxy_pass) + .replace("{{cert_path}}", &cert_path) + .replace("{{key_path}}", &key_path); + + let mut config_file = + NamedTempFile::new().expect("Failed to create temp file for nginx config"); + writeln!(config_file, "{}", nginx_conf).expect("Failed to write nginx config"); + let config_path = config_file.path().to_path_buf(); + let _child = Command::new("nginx") + .arg("-c") + .arg(config_path.as_os_str()) + .spawn() + .expect("Failed to start nginx"); + + let timeout = std::time::Duration::from_secs(5); + let start_time = std::time::Instant::now(); + loop { + match tokio::net::TcpStream::connect(format!("127.0.0.1:{}", n_https_port)).await { + Ok(_) => break, + Err(_) if start_time.elapsed() < timeout => { + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + } + Err(e) => return Err(Box::new(e)), + } + } + + // Keep the config file open as long as NGINX is using it + std::mem::forget(config_file); + + Ok(NginxProcess { _child, config_path }) + } + + fn full>(chunk: T) -> BoxBody { + Full::new(chunk.into()).map_err(|never| match never {}).boxed() + } +}