diff --git a/Cargo.lock b/Cargo.lock index 5d12c475..58d98fdb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -201,7 +201,6 @@ dependencies = [ "libc", "log", "phf", - "rand 0.9.3", "reqwest", "rocket", "serde", @@ -209,6 +208,7 @@ dependencies = [ "thiserror 1.0.69", "tokio", "tokio-test", + "toml", ] [[package]] @@ -1167,6 +1167,20 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.24.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec3efd23720e2049821a693cbc7e65ea87c72f1c58ff2f9522ff332b1491e590" +dependencies = [ + "futures-util", + "http 0.2.12", + "hyper", + "rustls", + "tokio", + "tokio-rustls", +] + [[package]] name = "hyper-tls" version = "0.5.0" @@ -1823,7 +1837,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c80231409c20246a13fddb31776fb942c38553c51e871f8cbd687a4cfb5843d" dependencies = [ "phf_shared", - "rand 0.8.5", + "rand", ] [[package]] @@ -1975,18 +1989,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" dependencies = [ "libc", - "rand_chacha 0.3.1", - "rand_core 0.6.4", -] - -[[package]] -name = "rand" -version = "0.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ec095654a25171c2124e9e3393a930bddbffdc939556c914957a4c3e0a87166" -dependencies = [ - "rand_chacha 0.9.0", - "rand_core 0.9.3", + "rand_chacha", + "rand_core", ] [[package]] @@ -1996,17 +2000,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" dependencies = [ "ppv-lite86", - "rand_core 0.6.4", -] - -[[package]] -name = "rand_chacha" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" -dependencies = [ - "ppv-lite86", - "rand_core 0.9.3", + "rand_core", ] [[package]] @@ -2018,15 +2012,6 @@ dependencies = [ "getrandom 0.2.16", ] -[[package]] -name = "rand_core" -version = "0.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" -dependencies = [ - "getrandom 0.3.3", -] - [[package]] name = "rayon" version = "1.10.0" @@ -2140,6 +2125,7 @@ dependencies = [ "http 0.2.12", "http-body", "hyper", + "hyper-rustls", "hyper-tls", "ipnet", "js-sys", @@ -2149,6 +2135,8 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", + "rustls", + "rustls-native-certs", "rustls-pemfile", "serde", "serde_json", @@ -2157,6 +2145,7 @@ dependencies = [ "system-configuration", "tokio", "tokio-native-tls", + "tokio-rustls", "tower-service", "url", "wasm-bindgen", @@ -2165,6 +2154,20 @@ dependencies = [ "winreg", ] +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.16", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + [[package]] name = "rocket" version = "0.5.1" @@ -2186,7 +2189,7 @@ dependencies = [ "num_cpus", "parking_lot", "pin-project-lite", - "rand 0.8.5", + "rand", "ref-cast", "rocket_codegen", "rocket_http", @@ -2334,6 +2337,30 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "rustls" +version = "0.21.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f56a14d1f48b391359b22f731fd4bd7e43c97f3c50eee276f3aa09c94784d3e" +dependencies = [ + "log", + "ring", + "rustls-webpki", + "sct", +] + +[[package]] +name = "rustls-native-certs" +version = "0.6.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "schannel", + "security-framework", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" @@ -2343,6 +2370,16 @@ dependencies = [ "base64", ] +[[package]] +name = "rustls-webpki" +version = "0.101.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.21" @@ -2410,6 +2447,16 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "sct" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" +dependencies = [ + "ring", + "untrusted", +] + [[package]] name = "sd-notify" version = "0.4.5" @@ -2832,6 +2879,16 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-rustls" +version = "0.24.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" +dependencies = [ + "rustls", + "tokio", +] + [[package]] name = "tokio-stream" version = "0.1.17" @@ -3036,6 +3093,12 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "url" version = "2.5.4" diff --git a/aw-client-rust/Cargo.toml b/aw-client-rust/Cargo.toml index 96e1ff04..8b0bc624 100644 --- a/aw-client-rust/Cargo.toml +++ b/aw-client-rust/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" authors = ["Johan Bjäreholt "] [dependencies] -reqwest = { version = "0.11", features = ["json", "blocking"] } +reqwest = { version = "0.11", default-features = false, features = ["json", "blocking", "rustls-tls-native-roots"] } gethostname = "0.4" serde = { version = "1.0", features = ["derive"] } phf = { version = "0.11", features = ["macros"] } @@ -13,10 +13,9 @@ serde_json = "1.0" chrono = { version = "0.4", features = ["serde"] } aw-models = { path = "../aw-models" } tokio = { version = "1.28.2", features = ["rt"] } -rand = "0.9" log = "0.4" libc = "0.2" -thiserror = "1.0" +thiserror = "1.0" dirs = "6.0" fs4 = { version = "0.13", features = ["sync"] } @@ -25,3 +24,4 @@ aw-datastore = { path = "../aw-datastore" } aw-server = { path = "../aw-server", default-features = false, features=[] } rocket = "0.5.0-rc.1" tokio-test = "*" +toml = "0.8" diff --git a/aw-client-rust/src/blocking.rs b/aw-client-rust/src/blocking.rs index c003072c..fbcd91fa 100644 --- a/aw-client-rust/src/blocking.rs +++ b/aw-client-rust/src/blocking.rs @@ -38,7 +38,16 @@ macro_rules! proxy_method impl AwClient { pub fn new(host: &str, port: u16, name: &str) -> Result> { - let async_client = AsyncAwClient::new(host, port, name)?; + Self::new_with_api_key(host, port, name, None) + } + + pub fn new_with_api_key( + host: &str, + port: u16, + name: &str, + api_key: Option, + ) -> Result> { + let async_client = AsyncAwClient::new_with_api_key(host, port, name, api_key)?; Ok(AwClient { baseurl: async_client.baseurl.clone(), diff --git a/aw-client-rust/src/lib.rs b/aw-client-rust/src/lib.rs index 42368604..74fb9c8b 100644 --- a/aw-client-rust/src/lib.rs +++ b/aw-client-rust/src/lib.rs @@ -13,6 +13,7 @@ pub mod single_instance; use std::{collections::HashMap, error::Error}; use chrono::{DateTime, Utc}; +use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION}; use serde_json::{json, Map}; use single_instance::SingleInstance; use std::net::TcpStream; @@ -39,13 +40,34 @@ fn get_hostname() -> String { gethostname::gethostname().to_string_lossy().to_string() } +fn build_client(api_key: Option) -> Result> { + let mut headers = HeaderMap::new(); + if let Some(api_key) = api_key { + let mut header_value = HeaderValue::from_str(&format!("Bearer {api_key}"))?; + header_value.set_sensitive(true); + headers.insert(AUTHORIZATION, header_value); + } + + Ok(reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(120)) + .default_headers(headers) + .build()?) +} + impl AwClient { pub fn new(host: &str, port: u16, name: &str) -> Result> { + Self::new_with_api_key(host, port, name, None) + } + + pub fn new_with_api_key( + host: &str, + port: u16, + name: &str, + api_key: Option, + ) -> Result> { let baseurl = reqwest::Url::parse(&format!("http://{}:{}", host, port))?; let hostname = get_hostname(); - let client = reqwest::Client::builder() - .timeout(std::time::Duration::from_secs(120)) - .build()?; + let client = build_client(api_key)?; //TODO: change localhost string to 127.0.0.1 for feature parity let single_instance_name = format!("{}-at-{}-on-{}", name, host, port); let single_instance = single_instance::SingleInstance::new(single_instance_name.as_str())?; diff --git a/aw-client-rust/tests/test.rs b/aw-client-rust/tests/test.rs index d518aa0a..9b8928d2 100644 --- a/aw-client-rust/tests/test.rs +++ b/aw-client-rust/tests/test.rs @@ -12,13 +12,87 @@ mod test { use aw_client_rust::Event; use chrono::{DateTime, Duration, Utc}; use serde_json::Map; + use std::cell::RefCell; + use std::fs; + use std::net::TcpListener; + use std::path::{Path, PathBuf}; use std::sync::Mutex; use std::thread; + use std::time::{SystemTime, UNIX_EPOCH}; use tokio_test::block_on; + // Config-reading helpers — only needed in tests; production code passes API keys explicitly. + + #[derive(serde::Deserialize, Default)] + struct LocalAuthConfig { + #[serde(default)] + api_key: Option, + } + + #[derive(serde::Deserialize, Default)] + struct LocalServerConfig { + #[serde(default)] + port: Option, + #[serde(default)] + auth: LocalAuthConfig, + } + + #[derive(Clone, Copy)] + struct ConfigCandidate { + filename: &'static str, + default_port: u16, + } + + fn get_server_config_dir() -> Option { + Some( + dirs::config_dir()? + .join("activitywatch") + .join("aw-server-rust"), + ) + } + + fn load_local_api_key(host: &str, port: u16) -> Option { + if host != "127.0.0.1" && host != "localhost" { + return None; + } + let config_dir = get_server_config_dir()?; + let candidates = [ + ConfigCandidate { + filename: "config.toml", + default_port: 5600, + }, + ConfigCandidate { + filename: "config-testing.toml", + default_port: 5666, + }, + ]; + for candidate in candidates { + let path = config_dir.join(candidate.filename); + let content = match fs::read_to_string(path) { + Ok(content) => content, + Err(_) => continue, + }; + let config: LocalServerConfig = match toml::from_str(&content) { + Ok(config) => config, + Err(_) => continue, + }; + let configured_port = config.port.unwrap_or(candidate.default_port); + if configured_port == port { + return config.auth.api_key.filter(|k| !k.is_empty()); + } + } + None + } + // A random port, but still not guaranteed to not be bound // FIXME: Bind to a port that is free for certain and use that for the client instead static PORT: u16 = 41293; + static ENV_LOCK: Mutex<()> = Mutex::new(()); + + // Keep the listener alive until the server binds — prevents TOCTOU race in reserve_port + thread_local! { + static RESERVED_PORT: RefCell> = RefCell::new(None); + } fn wait_for_server(timeout_s: u32, client: &AwClient) { for i in 0.. { @@ -36,7 +110,7 @@ mod test { } } - fn setup_testserver() -> rocket::Shutdown { + fn setup_testserver(port: u16, api_key: Option<&str>) -> rocket::Shutdown { use aw_server::endpoints::AssetResolver; use aw_server::endpoints::ServerState; @@ -46,7 +120,8 @@ mod test { device_id: "test_id".to_string(), }; let mut aw_config = aw_server::config::AWConfig::default(); - aw_config.port = PORT; + aw_config.port = port; + aw_config.auth.api_key = api_key.map(str::to_owned); let server = aw_server::endpoints::build_rocket(state, aw_config); let server = block_on(server.ignite()).unwrap(); let shutdown_handler = server.shutdown(); @@ -58,14 +133,64 @@ mod test { shutdown_handler } + fn reserve_port() -> u16 { + let listener = TcpListener::bind("127.0.0.1:0").unwrap(); + let port = listener.local_addr().unwrap().port(); + // Keep the listener alive until the server binds — prevents TOCTOU race + RESERVED_PORT.with(|cell| *cell.borrow_mut() = Some(listener)); + port + } + + fn write_server_config(port: u16, api_key: Option<&str>) -> PathBuf { + let unique = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .as_nanos(); + let config_home = std::env::temp_dir().join(format!( + "aw-client-rust-config-{}-{}", + std::process::id(), + unique + )); + let config_dir = config_home.join("activitywatch").join("aw-server-rust"); + fs::create_dir_all(&config_dir).unwrap(); + + let mut content = format!("port = {port}\n"); + if let Some(api_key) = api_key { + content.push_str("\n[auth]\n"); + content.push_str(&format!("api_key = \"{api_key}\"\n")); + } + fs::write(config_dir.join("config.toml"), content).unwrap(); + + config_home + } + + fn with_config_home(config_home: &Path, f: impl FnOnce() -> T) -> T { + let _guard = ENV_LOCK.lock().unwrap(); + let old_value = std::env::var_os("XDG_CONFIG_HOME"); + std::env::set_var("XDG_CONFIG_HOME", config_home); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(f)); + if let Some(old_value) = old_value { + std::env::set_var("XDG_CONFIG_HOME", old_value); + } else { + std::env::remove_var("XDG_CONFIG_HOME"); + } + let _ = fs::remove_dir_all(config_home); + result.unwrap() + } + #[test] fn test_full() { let clientname = "aw-client-rust-test"; - let client: AwClient = - AwClient::new("127.0.0.1", PORT, clientname).expect("Client creation failed"); + // Hold ENV_LOCK during client creation to prevent parallel-test interference + // with test_reads_api_key_from_matching_server_config (which holds the lock + // via with_config_home for the entire client+server lifetime). + let client: AwClient = { + let _guard = ENV_LOCK.lock().unwrap(); + AwClient::new("127.0.0.1", PORT, clientname).expect("Client creation failed") + }; - let shutdown_handler = setup_testserver(); + let shutdown_handler = setup_testserver(PORT, None); wait_for_server(20, &client); @@ -137,4 +262,37 @@ RETURN = events;", shutdown_handler.notify(); } + + // XDG_CONFIG_HOME is only respected by dirs::config_dir() on Linux. + // On macOS it returns $HOME/Library/Application Support (ignoring XDG_CONFIG_HOME), + // so this test would fail — gate it on Linux only. + #[test] + #[cfg(target_os = "linux")] + fn test_reads_api_key_from_matching_server_config() { + let clientname = "aw-client-rust-auth-test"; + let port = reserve_port(); + let config_home = write_server_config(port, Some("secret123")); + + with_config_home(&config_home, || { + let api_key = load_local_api_key("127.0.0.1", port); + let client: AwClient = + AwClient::new_with_api_key("127.0.0.1", port, clientname, api_key) + .expect("Client creation failed"); + // Drop the reserved listener before Rocket tries to bind the same port. + RESERVED_PORT.with(|cell| *cell.borrow_mut() = None); + let shutdown_handler = setup_testserver(port, Some("secret123")); + + wait_for_server(20, &client); + + let bucketname = format!("aw-client-rust-auth-test_{}", client.hostname); + client + .create_bucket_simple(&bucketname, "test-type") + .unwrap(); + + let bucket = client.get_bucket(&bucketname).unwrap(); + assert_eq!(bucket.id, bucketname); + + shutdown_handler.notify(); + }); + } }