diff --git a/Cargo-minimal.lock b/Cargo-minimal.lock index 18687afb0..69a980115 100644 --- a/Cargo-minimal.lock +++ b/Cargo-minimal.lock @@ -2644,6 +2644,8 @@ dependencies = [ "bhttp", "bitcoin 0.32.7", "bitcoin-ohttp", + "clap", + "config", "futures", "http-body-util", "hyper", @@ -2652,6 +2654,7 @@ dependencies = [ "payjoin", "prometheus", "redis", + "serde", "tempfile", "tokio", "tokio-rustls", diff --git a/Cargo-recent.lock b/Cargo-recent.lock index 18687afb0..69a980115 100644 --- a/Cargo-recent.lock +++ b/Cargo-recent.lock @@ -2644,6 +2644,8 @@ dependencies = [ "bhttp", "bitcoin 0.32.7", "bitcoin-ohttp", + "clap", + "config", "futures", "http-body-util", "hyper", @@ -2652,6 +2654,7 @@ dependencies = [ "payjoin", "prometheus", "redis", + "serde", "tempfile", "tokio", "tokio-rustls", diff --git a/payjoin-directory/Cargo.toml b/payjoin-directory/Cargo.toml index 1d2bcb2cf..2059adbc0 100644 --- a/payjoin-directory/Cargo.toml +++ b/payjoin-directory/Cargo.toml @@ -34,6 +34,9 @@ tokio-rustls = { version = "0.26.2", features = ["ring"], default-features = fal tracing = "0.1.41" tracing-subscriber = { version = "0.3.19", features = ["env-filter"] } prometheus = "0.13.4" +clap = { version = "4.5.45", features = ["derive", "env"] } +config = "0.15.14" +serde = { version = "1.0.219", features = ["derive"] } [dev-dependencies] tempfile = "3.20.0" diff --git a/payjoin-directory/src/cli.rs b/payjoin-directory/src/cli.rs new file mode 100644 index 000000000..1e447fb80 --- /dev/null +++ b/payjoin-directory/src/cli.rs @@ -0,0 +1,55 @@ +use std::env; +use std::path::PathBuf; + +use clap::{value_parser, Parser}; + +#[derive(Debug, Parser)] +#[command( + version = env!("CARGO_PKG_VERSION"), + about = "Payjoin Directory Server", + long_about = None, +)] +pub struct Cli { + #[arg( + long, + short = 'p', + env = "PJ_DIR_PORT", + default_value = "8080", + help = "The port to bind" + )] + pub port: u16, // TODO tokio_listener::ListenerAddressLFlag + + #[arg( + long, + short = 'p', + env = "PJ_METRIC_PORT", + default_value = "9090", + help = "The port to bind for prometheus metrics export" + )] + pub metrics_port: u16, // TODO tokio_listener::ListenerAddressLFlag + + #[arg( + long, + env = "PJ_DIR_TIMEOUT_SECS", + default_value = "30", + help = "The timeout for long polling operations" + )] + pub timeout: u64, + + #[arg( + long = "db-host", + env = "PJ_DB_HOST", + default_value = "localhost:6379", + help = "The redis host to connect to" + )] + pub db_host: String, + + #[arg( + long = "ohttp-keys", + env = "PJ_OHTTP_KEY_DIR", + help = "The ohttp key config file path", + default_value = "ohttp_keys", + value_parser = value_parser!(PathBuf) + )] + pub ohttp_keys: PathBuf, +} diff --git a/payjoin-directory/src/config.rs b/payjoin-directory/src/config.rs new file mode 100644 index 000000000..4f649134b --- /dev/null +++ b/payjoin-directory/src/config.rs @@ -0,0 +1,52 @@ +use std::path::PathBuf; +use std::time::Duration; + +use anyhow::Result; +use config::builder::DefaultState; +use config::{ConfigError, File, FileFormat}; +use serde::Deserialize; + +type Builder = config::builder::ConfigBuilder; + +use crate::cli::Cli; + +#[derive(Debug, Clone, Deserialize)] +pub struct Config { + pub listen_addr: String, // TODO tokio_listener::ListenerAddressLFlag + pub metrics_listen_addr: String, // TODO tokio_listener::ListenerAddressLFlag + pub timeout: Duration, + pub db_host: String, + pub ohttp_keys: PathBuf, // TODO OhttpConfig struct with rotation params, etc +} + +impl Config { + pub fn new(cli: &Cli) -> Result { + let mut config = config::Config::builder(); + config = add_defaults(config, cli)?; + + // what directory should this reside in? require explicit --config-file? ~/.config? /etc? + config = config.add_source(File::new("config.toml", FileFormat::Toml).required(false)); + + let built_config = config.build()?; + + Ok(Config { + listen_addr: built_config.get("listen_addr")?, + metrics_listen_addr: built_config.get("metrics_listen_addr")?, + timeout: Duration::from_secs(built_config.get("timeout")?), + db_host: built_config.get("db_host")?, + ohttp_keys: built_config.get("ohttp_keys")?, + }) + } +} + +fn add_defaults(config: Builder, cli: &Cli) -> Result { + config + .set_override_option("listen_addr", Some(format!("[::]:{}", cli.port)))? + .set_override_option( + "metrics_listen_addr", + Some(format!("localhost:{}", cli.metrics_port)), + )? + .set_override_option("timeout", Some(cli.timeout))? + .set_override_option("db_host", Some(cli.db_host.to_owned()))? + .set_override_option("ohttp_keys", Some(cli.ohttp_keys.to_string_lossy().into_owned())) +} diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index 0c2c4f033..71caff21a 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -16,10 +16,6 @@ pub use crate::db::DbPool; pub mod key_config; pub use crate::key_config::*; use crate::metrics::Metrics; -pub const DEFAULT_DIR_PORT: u16 = 8080; -pub const DEFAULT_DB_HOST: &str = "localhost:6379"; -pub const DEFAULT_TIMEOUT_SECS: u64 = 30; -pub const DEFAULT_METRIC_PORT: u16 = 9090; const CHACHA20_POLY1305_NONCE_LEN: usize = 32; // chacha20poly1305 n_k const POLY1305_TAG_SIZE: usize = 16; @@ -33,6 +29,8 @@ const V1_UNAVAILABLE_RES_JSON: &str = r#"{{"errorCode": "unavailable", "message" mod db; +pub mod cli; +pub mod config; pub mod metrics; pub type BoxError = Box; diff --git a/payjoin-directory/src/main.rs b/payjoin-directory/src/main.rs index 655675bb6..45fc825bd 100644 --- a/payjoin-directory/src/main.rs +++ b/payjoin-directory/src/main.rs @@ -1,36 +1,20 @@ -use std::env; -use std::net::{IpAddr, Ipv6Addr, SocketAddr}; - +use clap::Parser; use payjoin_directory::metrics::Metrics; use payjoin_directory::*; use tokio::net::TcpListener; +use tracing::error; use tracing_subscriber::filter::LevelFilter; use tracing_subscriber::EnvFilter; -const DEFAULT_KEY_CONFIG_DIR: &str = "ohttp_keys"; - #[tokio::main] async fn main() -> Result<(), BoxError> { init_logging(); - let dir_port = - env::var("PJ_DIR_PORT").map_or(DEFAULT_DIR_PORT, |s| s.parse().expect("Invalid port")); - - let metric_port = env::var("PJ_METRIC_PORT") - .map_or(DEFAULT_METRIC_PORT, |s| s.parse().expect("invalid metric port")); - - let timeout_env = env::var("PJ_DIR_TIMEOUT_SECS") - .map_or(DEFAULT_TIMEOUT_SECS, |s| s.parse().expect("Invalid timeout")); - let timeout = std::time::Duration::from_secs(timeout_env); + let cli = cli::Cli::parse(); + let config = config::Config::new(&cli)?; - let db_host = env::var("PJ_DB_HOST").unwrap_or_else(|_| DEFAULT_DB_HOST.to_string()); - - let key_dir = - std::env::var("PJ_OHTTP_KEY_DIR").map(std::path::PathBuf::from).unwrap_or_else(|_| { - let key_dir = std::path::PathBuf::from(DEFAULT_KEY_CONFIG_DIR); - std::fs::create_dir_all(&key_dir).expect("Failed to create key directory"); - key_dir - }); + let key_dir = config.ohttp_keys; + std::fs::create_dir_all(&key_dir).expect("Failed to create key directory"); let ohttp = match key_config::read_server_config(&key_dir) { Ok(config) => config, @@ -42,26 +26,24 @@ async fn main() -> Result<(), BoxError> { } }; - // Start metrics server in the background + let db = DbPool::new(config.timeout, config.db_host).await?; let metrics = Metrics::new(); - let metrics_listener = bind_port(metric_port).await?; - let listener = bind_port(dir_port).await?; - let db = DbPool::new(timeout, db_host).await?; let service = Service::new(db, ohttp.into(), metrics); - let service_clone = service.clone(); - tokio::spawn(async move { - if let Err(e) = payjoin_directory::serve_metrics_tcp(service_clone, metrics_listener).await - { - eprintln!("Metrics server error: {e}"); - } - }); - service.serve_tcp(listener).await -} -async fn bind_port(port: u16) -> Result { - let bind_addr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), port); - TcpListener::bind(bind_addr).await + // Start metrics server in the background + let metrics_listener = TcpListener::bind(config.metrics_listen_addr).await?; + { + let service = service.clone(); + tokio::spawn(async move { + if let Err(e) = payjoin_directory::serve_metrics_tcp(service, metrics_listener).await { + error!("Metrics server error: {e}"); + } + }); + } + + let listener = TcpListener::bind(config.listen_addr).await?; + service.serve_tcp(listener).await } fn init_logging() {