diff --git a/doc/guide/authentication.md b/doc/guide/authentication.md
new file mode 100644
index 000000000..5bd727587
--- /dev/null
+++ b/doc/guide/authentication.md
@@ -0,0 +1,133 @@
+---
+title: moq-authentication
+description: Authentication for the moq-relay
+---
+
+# MoQ authentication
+
+The MoQ relay authenticates clients via JWT-based tokens. There are two approaches to choose from:
+- asymmetric keys: a private key signs tokens and a separate public key verifies them, which is more secure
+- symmetric key: a single secret key both signs and verifies, which is simpler but less secure
+
+## Symmetric key
+
+1. Generate a secret key:
+```bash
+moq-token --key root.jwk generate --algorithm HS256
+```
+:::details You can also choose a different algorithm
+- HS256
+- HS384
+- HS512
+:::
+
+2. Configure the relay:
+:::code-group
+```toml [relay.toml]
+[auth]
+# public = "anon" # Optional: allow anonymous access to anon/**
+key = "root.jwk" # JWT key for authenticated paths
+```
+:::
+
+3. Generate tokens:
+```bash
+moq-token --key root.jwk sign \
+  --root "rooms/123" \
+  --publish "alice" \
+  --subscribe "" \
+  --expires 1735689600 > alice.jwt
+```
+
+## Asymmetric keys
+
+Asymmetric keys are generally more secure because you don't need to distribute the signing key to every relay instance; the relays only need the verifying (public) key.
+
+1. Generate a public and private key:
+```bash
+moq-token --key private.jwk generate --public public.jwk --algorithm RS256
+```
+:::details You can also choose a different algorithm
+- RS256
+- RS384
+- RS512
+- PS256
+- PS384
+- PS512
+- ES256
+- ES384
+- EdDSA
+:::
+
+2. Now the relay only requires the public key:
+:::code-group
+```toml [relay.toml]
+[auth]
+# public = "anon" # Optional: allow anonymous access to anon/**
+key = "public.jwk" # JWT key for authenticated paths
+```
+:::
+
+3. Generate tokens using the private key:
+```bash
+moq-token --key private.jwk sign \
+  --root "rooms/123" \
+  --publish "alice" \
+  --subscribe "" \
+  --expires 1735689600 > alice.jwt
+```
+
+## JWK set authentication
+
+Instead of storing the public key in a local file, the relay can also retrieve it from a server hosting a JWK set. This can be a simple static site serving a JSON file, or a fully OIDC-compliant identity provider. This makes it easy to implement automatic key rotation.
+
+::: info
+This approach only works with asymmetric authentication.
+:::
+
+To set this up, you need an HTTPS server hosting a JWK set that looks like this:
+```json
+{
+  "keys": [
+    {
+      "kid": "2026-01-01",
+      "alg": "RS256",
+      "key_ops": [
+        "verify"
+      ],
+      "kty": "RSA",
+      "n": "zMsjX1oDV2SMQKZFTx4_qCaD3iIek9s1lvVaymr8bEGzO4pe6syCwBwLmFwaixRv7MMsuZ0nIpoR3Slpo-ZVyRxOc8yc3DcBZx49S_UQcM76E4MYbH6oInrEP8QL2bsstHrYTqTyPPjGwQJVp_sZdkjKlF5N-v5ohpn36sI8PXELvfRY3O3bad-RmSZ8ZOG8CYnJvMj_g2lYtGMMThnddnJ49560ahUNqAbH6ru---sHtdYHcjTIaWX4HYP6Y_KjA6siDZTGTThpaEW45LKcDQWM9sYvx_eAstaC-1rz8Z_6fDgKFWr7qcP5U2NmJ0c-IGSu_8OkftgRH4--Z5mzBQ",
+      "e": "AQAB"
+    },
+    {
+      "kid": "2025-12-01",
+      "alg": "EdDSA",
+      "key_ops": [
+        "verify"
+      ],
+      "kty": "OKP",
+      "crv": "Ed25519",
+      "x": "2FSK2q_o_d5ernBmNQLNMFxiA4-ypBSa4LsN30ZjUeU"
+    }
+  ]
+}
+```
+
+:::tip The following must be considered:
+- Every JWK MUST be public and contain no private key information
+- If your JWK set contains more than one key:
+  1. Every JWK MUST have a `kid` so it can be identified during verification
+  2. Your JWT tokens MUST contain a `kid` in their header (see the example below)
+  3. `kid` can be an arbitrary string
+:::
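+
+For illustration only (this header is not part of the repository), a token signed with the first key above would carry a JOSE header like the following; the relay uses the `kid` in this header to pick the matching JWK during verification:
+```json
+{
+  "alg": "RS256",
+  "kid": "2026-01-01"
+}
+```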
+
+Configure the relay:
+:::code-group
+```toml [relay.toml]
+[auth]
+# public = "anon" # Optional: allow anonymous access to anon/**
+
+key = "https://auth.example.com/keys.json" # JWK set URL for authenticated paths
+jwks_refresh_interval = 86400 # Optional: refresh the JWK set every N seconds, no refreshing if omitted
+```
+:::
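+
+## Connecting with a token
+
+Regardless of which approach you use, clients present the resulting token when connecting: the relay looks for it in the `jwt` query parameter of the connection URL (see `Auth::verify` in `moq-relay`). For example, with a hypothetical relay at `relay.example.com` and the `alice.jwt` generated above, a client would connect to a URL of the form:
+```text
+https://relay.example.com/rooms/123?jwt=<contents of alice.jwt>
+```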
diff --git a/rs/moq-relay/Cargo.toml b/rs/moq-relay/Cargo.toml
index ffbeac6ec..3c6d6c357 100644
--- a/rs/moq-relay/Cargo.toml
+++ b/rs/moq-relay/Cargo.toml
@@ -30,7 +30,7 @@ futures = "0.3"
 http-body = "1"
 moq-lite = { workspace = true, features = ["serde"] }
 moq-native = { workspace = true, features = ["aws-lc-rs"] }
-moq-token = { workspace = true }
+moq-token = { workspace = true, features = ["jwks-loader"] }
 rustls = { version = "0.23", features = [
 	"aws-lc-rs",
 ], default-features = false }
diff --git a/rs/moq-relay/src/auth.rs b/rs/moq-relay/src/auth.rs
index 060602b17..45b48f04b 100644
--- a/rs/moq-relay/src/auth.rs
+++ b/rs/moq-relay/src/auth.rs
@@ -1,8 +1,10 @@
-use std::sync::Arc;
-
+use anyhow::Context;
 use axum::http;
 use moq_lite::{AsPath, Path, PathOwned};
+use moq_token::KeySet;
 use serde::{Deserialize, Serialize};
+use std::sync::{Arc, Mutex};
+use std::time::Duration;
 
 #[derive(thiserror::Error, Debug, Clone)]
 pub enum AuthError {
@@ -34,11 +36,17 @@ impl axum::response::IntoResponse for AuthError {
 #[derive(clap::Args, Clone, Debug, Serialize, Deserialize, Default)]
 #[serde(default)]
 pub struct AuthConfig {
-	/// The root authentication key.
+	/// Either the root authentication key or a URI to a JWK set.
 	/// If present, all paths will require a token unless they are in the public list.
 	#[arg(long = "auth-key", env = "MOQ_AUTH_KEY")]
 	pub key: Option<String>,
 
+	/// How often to refresh the JWK set (in seconds); ignored if `key` is not a valid URI.
+	/// If not provided, the JWK set is only loaded once at startup and never refreshed.
+	/// Minimum value: 30; defaults to None.
+	#[arg(long = "jwks-refresh-interval", env = "MOQ_AUTH_JWKS_REFRESH_INTERVAL")]
+	pub jwks_refresh_interval: Option<u64>,
+
 	/// The prefix that will be public for reading and writing.
 	/// If present, unauthorized users will be able to read and write to this prefix ONLY.
 	/// If a user provides a token, then they can only access the prefix only if it is specified in the token.
@@ -47,8 +55,8 @@ pub struct AuthConfig {
 }
 
 impl AuthConfig {
-	pub fn init(self) -> anyhow::Result<Auth> {
-		Auth::new(self)
+	pub async fn init(self) -> anyhow::Result<Auth> {
+		Auth::new(self).await
 	}
 }
 
@@ -60,15 +68,119 @@ pub struct AuthToken {
 	pub cluster: bool,
 }
 
+const JWKS_REFRESH_ERROR_INTERVAL: Duration = Duration::from_mins(5);
+
 #[derive(Clone)]
 pub struct Auth {
-	key: Option<Arc<moq_token::Key>>,
+	key: Option<Arc<Mutex<KeySet>>>,
 	public: Option<PathOwned>,
+	refresh_task: Option<Arc<tokio::task::JoinHandle<()>>>,
+}
+
+impl Drop for Auth {
+	fn drop(&mut self) {
+		if let Some(handle) = self.refresh_task.as_ref()
+			&& Arc::strong_count(handle) == 1
+		{
+			handle.abort();
+		}
+	}
 }
 
 impl Auth {
-	pub fn new(config: AuthConfig) -> anyhow::Result<Self> {
-		let key = config.key.as_deref().map(moq_token::Key::from_file).transpose()?;
+	fn compare_key_sets(previous: &KeySet, new: &KeySet) {
+		for new_key in new.keys.iter() {
+			if new_key.kid.is_some() && !previous.keys.iter().any(|k| k.kid == new_key.kid) {
+				tracing::info!("found new JWK \"{}\"", new_key.kid.as_deref().unwrap())
+			}
+		}
+		for old_key in previous.keys.iter() {
+			if old_key.kid.is_some() && !new.keys.iter().any(|k| k.kid == old_key.kid) {
+				tracing::info!("removed JWK \"{}\"", old_key.kid.as_deref().unwrap())
+			}
+		}
+	}
+
+	async fn refresh_key_set(jwks_uri: &str, key_set: &Mutex<KeySet>) -> anyhow::Result<()> {
+		let new_keys = moq_token::load_keys(jwks_uri).await?;
+
+		let mut key_set = key_set.lock().expect("keyset mutex poisoned");
+		Self::compare_key_sets(&key_set, &new_keys);
+		*key_set = new_keys;
+
+		Ok(())
+	}
+
+	fn spawn_refresh_task(
+		interval: Duration,
+		key_set: Arc<Mutex<KeySet>>,
+		jwks_uri: String,
+	) -> tokio::task::JoinHandle<()> {
+		tokio::spawn(async move {
+			loop {
+				tokio::time::sleep(interval).await;
+
+				if let Err(e) = Self::refresh_key_set(&jwks_uri, key_set.as_ref()).await {
+					if interval > JWKS_REFRESH_ERROR_INTERVAL * 2 {
+						tracing::error!(
+							"failed to load JWKS, will retry in {} seconds: {:?}",
+							JWKS_REFRESH_ERROR_INTERVAL.as_secs(),
+							e
+						);
+						tokio::time::sleep(JWKS_REFRESH_ERROR_INTERVAL).await;
+
+						if let Err(e) = Self::refresh_key_set(&jwks_uri, key_set.as_ref()).await {
+							tracing::error!("failed to load JWKS again, giving up this time: {:?}", e);
+						} else {
+							tracing::info!("successfully loaded JWKS on the second try");
+						}
+					} else {
+						// Don't retry because the next refresh is going to happen very soon
+						tracing::error!("failed to refresh JWKS: {:?}", e);
+					}
+				}
+			}
+		})
+	}
+
+	pub async fn new(config: AuthConfig) -> anyhow::Result<Self> {
+		let mut refresh_task = None;
+
+		let key = match config.key {
+			Some(uri) if uri.starts_with("http://") || uri.starts_with("https://") => {
+				// Start with an empty KeySet
+				let key_set = Arc::new(Mutex::new(KeySet::default()));
+
+				tracing::info!("loading JWK set from {}", &uri);
+
+				Self::refresh_key_set(&uri, key_set.as_ref()).await?;
+
+				if let Some(refresh_interval_secs) = config.jwks_refresh_interval {
+					anyhow::ensure!(
+						refresh_interval_secs >= 30,
+						"jwks_refresh_interval cannot be less than 30"
+					);
+
+					refresh_task = Some(Self::spawn_refresh_task(
+						Duration::from_secs(refresh_interval_secs),
+						key_set.clone(),
+						uri,
+					));
+				}
+
+				Some(key_set)
+			}
+
+			Some(key_file) => {
+				let key = moq_token::Key::from_file(&key_file)
+					.with_context(|| format!("cannot load key from {}", &key_file))?;
+				Some(Arc::new(Mutex::new(KeySet {
+					keys: vec![Arc::new(key)],
+				})))
+			}
+
+			None => None,
+		};
 
 		let public = config.public;
 
@@ -79,8 +191,9 @@ impl Auth {
 		}
 
 		Ok(Self {
-			key: key.map(Arc::new),
+			key,
 			public: public.map(|p| p.as_path().to_owned()),
+
refresh_task: refresh_task.map(Arc::new), }) } @@ -90,9 +203,12 @@ impl Auth { // Find the token in the query parameters. // ?jwt=... let claims = if let Some(token) = token - && let Some(key) = self.key.as_ref() + && let Some(key) = self.key.as_deref() { - key.decode(token).map_err(|_| AuthError::DecodeFailed)? + key.lock() + .expect("key mutex poisoned") + .decode(token) + .map_err(|_| AuthError::DecodeFailed)? } else if let Some(_token) = token { return Err(AuthError::UnexpectedToken); } else if let Some(public) = &self.public { @@ -115,7 +231,7 @@ impl Auth { return Err(AuthError::IncorrectRoot); }; - // If a more specific path is is provided, reduce the permissions. + // If a more specific path is provided, reduce the permissions. let subscribe = claims .subscribe .into_iter() @@ -164,13 +280,14 @@ mod tests { Ok((key_file, key)) } - #[test] - fn test_anonymous_access_with_public_path() -> anyhow::Result<()> { + #[tokio::test] + async fn test_anonymous_access_with_public_path() -> anyhow::Result<()> { // Test anonymous access to /anon path let auth = Auth::new(AuthConfig { - key: None, public: Some("anon".to_string()), - })?; + ..Default::default() + }) + .await?; // Should succeed for anonymous path let token = auth.verify("/anon", None)?; @@ -187,13 +304,14 @@ mod tests { Ok(()) } - #[test] - fn test_anonymous_access_fully_public() -> anyhow::Result<()> { + #[tokio::test] + async fn test_anonymous_access_fully_public() -> anyhow::Result<()> { // Test fully public access (public = "") let auth = Auth::new(AuthConfig { - key: None, public: Some("".to_string()), - })?; + ..Default::default() + }) + .await?; // Should succeed for any path let token = auth.verify("/any/path", None)?; @@ -204,13 +322,14 @@ mod tests { Ok(()) } - #[test] - fn test_anonymous_access_denied_wrong_prefix() -> anyhow::Result<()> { + #[tokio::test] + async fn test_anonymous_access_denied_wrong_prefix() -> anyhow::Result<()> { // Test anonymous access denied for wrong prefix let auth = Auth::new(AuthConfig { - key: None, public: Some("anon".to_string()), - })?; + ..Default::default() + }) + .await?; // Should fail for non-anonymous path let result = auth.verify("/secret", None); @@ -219,13 +338,14 @@ mod tests { Ok(()) } - #[test] - fn test_no_token_no_public_path_fails() -> anyhow::Result<()> { + #[tokio::test] + async fn test_no_token_no_public_path_fails() -> anyhow::Result<()> { let (key_file, _) = create_test_key()?; let auth = Auth::new(AuthConfig { key: Some(key_file.path().to_string_lossy().to_string()), - public: None, - })?; + ..Default::default() + }) + .await?; // Should fail when no token and no public path let result = auth.verify("/any/path", None); @@ -234,12 +354,13 @@ mod tests { Ok(()) } - #[test] - fn test_token_provided_but_no_key_configured() -> anyhow::Result<()> { + #[tokio::test] + async fn test_token_provided_but_no_key_configured() -> anyhow::Result<()> { let auth = Auth::new(AuthConfig { - key: None, public: Some("anon".to_string()), - })?; + ..Default::default() + }) + .await?; // Should fail when token provided but no key configured let result = auth.verify("/any/path", Some("fake-token")); @@ -248,13 +369,14 @@ mod tests { Ok(()) } - #[test] - fn test_jwt_token_basic_validation() -> anyhow::Result<()> { + #[tokio::test] + async fn test_jwt_token_basic_validation() -> anyhow::Result<()> { let (key_file, key) = create_test_key()?; let auth = Auth::new(AuthConfig { key: Some(key_file.path().to_string_lossy().to_string()), - public: None, - })?; + ..Default::default() + }) + .await?; // 
Create a token with basic permissions let claims = moq_token::Claims { @@ -274,13 +396,14 @@ mod tests { Ok(()) } - #[test] - fn test_jwt_token_wrong_root_path() -> anyhow::Result<()> { + #[tokio::test] + async fn test_jwt_token_wrong_root_path() -> anyhow::Result<()> { let (key_file, key) = create_test_key()?; let auth = Auth::new(AuthConfig { key: Some(key_file.path().to_string_lossy().to_string()), - public: None, - })?; + ..Default::default() + }) + .await?; // Create a token for room/123 let claims = moq_token::Claims { @@ -298,13 +421,14 @@ mod tests { Ok(()) } - #[test] - fn test_jwt_token_with_restricted_publish_subscribe() -> anyhow::Result<()> { + #[tokio::test] + async fn test_jwt_token_with_restricted_publish_subscribe() -> anyhow::Result<()> { let (key_file, key) = create_test_key()?; let auth = Auth::new(AuthConfig { key: Some(key_file.path().to_string_lossy().to_string()), - public: None, - })?; + ..Default::default() + }) + .await?; // Create a token with specific pub/sub restrictions let claims = moq_token::Claims { @@ -324,13 +448,14 @@ mod tests { Ok(()) } - #[test] - fn test_jwt_token_read_only() -> anyhow::Result<()> { + #[tokio::test] + async fn test_jwt_token_read_only() -> anyhow::Result<()> { let (key_file, key) = create_test_key()?; let auth = Auth::new(AuthConfig { key: Some(key_file.path().to_string_lossy().to_string()), - public: None, - })?; + ..Default::default() + }) + .await?; // Create a read-only token (no publish permissions) let claims = moq_token::Claims { @@ -348,13 +473,14 @@ mod tests { Ok(()) } - #[test] - fn test_jwt_token_write_only() -> anyhow::Result<()> { + #[tokio::test] + async fn test_jwt_token_write_only() -> anyhow::Result<()> { let (key_file, key) = create_test_key()?; let auth = Auth::new(AuthConfig { key: Some(key_file.path().to_string_lossy().to_string()), - public: None, - })?; + ..Default::default() + }) + .await?; // Create a write-only token (no subscribe permissions) let claims = moq_token::Claims { @@ -372,13 +498,14 @@ mod tests { Ok(()) } - #[test] - fn test_claims_reduction_basic() -> anyhow::Result<()> { + #[tokio::test] + async fn test_claims_reduction_basic() -> anyhow::Result<()> { let (key_file, key) = create_test_key()?; let auth = Auth::new(AuthConfig { key: Some(key_file.path().to_string_lossy().to_string()), - public: None, - })?; + ..Default::default() + }) + .await?; // Create a token with root at room/123 and unrestricted pub/sub let claims = moq_token::Claims { @@ -401,13 +528,14 @@ mod tests { Ok(()) } - #[test] - fn test_claims_reduction_with_publish_restrictions() -> anyhow::Result<()> { + #[tokio::test] + async fn test_claims_reduction_with_publish_restrictions() -> anyhow::Result<()> { let (key_file, key) = create_test_key()?; let auth = Auth::new(AuthConfig { key: Some(key_file.path().to_string_lossy().to_string()), - public: None, - })?; + ..Default::default() + }) + .await?; // Token allows publishing only to alice/* let claims = moq_token::Claims { @@ -430,13 +558,14 @@ mod tests { Ok(()) } - #[test] - fn test_claims_reduction_with_subscribe_restrictions() -> anyhow::Result<()> { + #[tokio::test] + async fn test_claims_reduction_with_subscribe_restrictions() -> anyhow::Result<()> { let (key_file, key) = create_test_key()?; let auth = Auth::new(AuthConfig { key: Some(key_file.path().to_string_lossy().to_string()), - public: None, - })?; + ..Default::default() + }) + .await?; // Token allows subscribing only to bob/* let claims = moq_token::Claims { @@ -458,13 +587,14 @@ mod tests { Ok(()) } - #[test] - fn 
test_claims_reduction_loses_access() -> anyhow::Result<()> { + #[tokio::test] + async fn test_claims_reduction_loses_access() -> anyhow::Result<()> { let (key_file, key) = create_test_key()?; let auth = Auth::new(AuthConfig { key: Some(key_file.path().to_string_lossy().to_string()), - public: None, - })?; + ..Default::default() + }) + .await?; // Token allows publishing to alice/* and subscribing to bob/* let claims = moq_token::Claims { @@ -496,13 +626,14 @@ mod tests { Ok(()) } - #[test] - fn test_claims_reduction_nested_paths() -> anyhow::Result<()> { + #[tokio::test] + async fn test_claims_reduction_nested_paths() -> anyhow::Result<()> { let (key_file, key) = create_test_key()?; let auth = Auth::new(AuthConfig { key: Some(key_file.path().to_string_lossy().to_string()), - public: None, - })?; + ..Default::default() + }) + .await?; // Token with nested publish/subscribe paths let claims = moq_token::Claims { @@ -533,13 +664,14 @@ mod tests { Ok(()) } - #[test] - fn test_claims_reduction_preserves_read_write_only() -> anyhow::Result<()> { + #[tokio::test] + async fn test_claims_reduction_preserves_read_write_only() -> anyhow::Result<()> { let (key_file, key) = create_test_key()?; let auth = Auth::new(AuthConfig { key: Some(key_file.path().to_string_lossy().to_string()), - public: None, - })?; + ..Default::default() + }) + .await?; // Read-only token let claims = moq_token::Claims { diff --git a/rs/moq-relay/src/main.rs b/rs/moq-relay/src/main.rs index 0411ea82f..b188ce984 100644 --- a/rs/moq-relay/src/main.rs +++ b/rs/moq-relay/src/main.rs @@ -43,7 +43,7 @@ async fn main() -> anyhow::Result<()> { client.with_iroh(iroh); } - let auth = config.auth.init()?; + let auth = config.auth.init().await?; let cluster = Cluster::new(config.cluster, client); let cloned = cluster.clone(); diff --git a/rs/moq-token/Cargo.toml b/rs/moq-token/Cargo.toml index 6937f4358..2ac7390f1 100644 --- a/rs/moq-token/Cargo.toml +++ b/rs/moq-token/Cargo.toml @@ -8,6 +8,9 @@ license = "MIT OR Apache-2.0" version = "0.5.6" edition = "2024" +[features] +jwks-loader = ["reqwest"] + [dependencies] anyhow = "1" aws-lc-rs = "1" @@ -16,6 +19,7 @@ elliptic-curve = "0.13.8" jsonwebtoken = { version = "10", features = ["aws_lc_rs"] } p256 = "0.13.2" p384 = "0.13.1" +reqwest = { version = "0.13.1", optional = true } rsa = "0.9.10" serde = { version = "1", features = ["derive"] } serde_json = "1" diff --git a/rs/moq-token/src/lib.rs b/rs/moq-token/src/lib.rs index 38ffcc84b..a8ca60387 100644 --- a/rs/moq-token/src/lib.rs +++ b/rs/moq-token/src/lib.rs @@ -9,7 +9,9 @@ mod algorithm; mod claims; mod generate; mod key; +mod set; pub use algorithm::*; pub use claims::*; pub use key::*; +pub use set::*; diff --git a/rs/moq-token/src/set.rs b/rs/moq-token/src/set.rs new file mode 100644 index 000000000..9b8778842 --- /dev/null +++ b/rs/moq-token/src/set.rs @@ -0,0 +1,421 @@ +use crate::{Claims, Key, KeyOperation}; +use anyhow::Context; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; +use std::path::Path; +use std::sync::Arc; +use std::time::Duration; + +/// JWK Set to spec +#[derive(Default, Clone)] +pub struct KeySet { + /// Vec of an arbitrary number of Json Web Keys + pub keys: Vec>, +} + +impl Serialize for KeySet { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + // Serialize as a struct with a `keys` field + use serde::ser::SerializeStruct; + + let mut state = serializer.serialize_struct("KeySet", 1)?; + state.serialize_field("keys", &self.keys.iter().map(|k| k.as_ref()).collect::>())?; 
+		state.end()
+	}
+}
+
+impl<'de> Deserialize<'de> for KeySet {
+	fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+	where
+		D: Deserializer<'de>,
+	{
+		// Deserialize into a temporary Vec
+		#[derive(Deserialize)]
+		struct RawKeySet {
+			keys: Vec<Key>,
+		}
+
+		let raw = RawKeySet::deserialize(deserializer)?;
+		Ok(KeySet {
+			keys: raw.keys.into_iter().map(Arc::new).collect(),
+		})
+	}
+}
+
+impl KeySet {
+	#[allow(clippy::should_implement_trait)]
+	pub fn from_str(s: &str) -> anyhow::Result<Self> {
+		Ok(serde_json::from_str(s)?)
+	}
+
+	pub fn from_file<P: AsRef<Path>>(path: P) -> anyhow::Result<Self> {
+		let json = std::fs::read_to_string(&path)?;
+		Ok(serde_json::from_str(&json)?)
+	}
+
+	pub fn to_str(&self) -> anyhow::Result<String> {
+		Ok(serde_json::to_string(&self)?)
+	}
+
+	pub fn to_file<P: AsRef<Path>>(&self, path: P) -> anyhow::Result<()> {
+		let json = serde_json::to_string(&self)?;
+		std::fs::write(path, json)?;
+		Ok(())
+	}
+
+	pub fn to_public_set(&self) -> anyhow::Result<Self> {
+		Ok(KeySet {
+			keys: self
+				.keys
+				.iter()
+				.map(|key| {
+					key.as_ref()
+						.to_public()
+						.map(Arc::new)
+						.map_err(|e| anyhow::anyhow!("failed to get public key from jwks: {:?}", e))
+				})
+				.collect::<Result<Vec<_>, _>>()?,
+		})
+	}
+
+	pub fn find_key(&self, kid: &str) -> Option<Arc<Key>> {
+		self.keys.iter().find(|k| k.kid.as_deref() == Some(kid)).cloned()
+	}
+
+	pub fn find_supported_key(&self, operation: &KeyOperation) -> Option<Arc<Key>> {
+		self.keys.iter().find(|key| key.operations.contains(operation)).cloned()
+	}
+
+	pub fn encode(&self, payload: &Claims) -> anyhow::Result<String> {
+		let key = self
+			.find_supported_key(&KeyOperation::Sign)
+			.context("cannot find signing key")?;
+		key.encode(payload)
+	}
+
+	pub fn decode(&self, token: &str) -> anyhow::Result<Claims> {
+		let header = jsonwebtoken::decode_header(token).context("failed to decode JWT header")?;
+
+		let key = match header.kid {
+			Some(kid) => self
+				.find_key(kid.as_str())
+				.ok_or_else(|| anyhow::anyhow!("cannot find key with kid {kid}")),
+			None => {
+				// If we only have one key we can use it without a kid
+				if self.keys.len() == 1 {
+					Ok(self.keys[0].clone())
+				} else {
+					anyhow::bail!("missing kid in JWT header")
+				}
+			}
+		}?;
+
+		key.decode(token)
+	}
+}
+
+#[cfg(feature = "jwks-loader")]
+pub async fn load_keys(jwks_uri: &str) -> anyhow::Result<KeySet> {
+	let client = reqwest::Client::builder()
+		.timeout(Duration::from_secs(10))
+		.build()
+		.context("failed to build reqwest client")?;
+
+	let jwks_json = client
+		.get(jwks_uri)
+		.send()
+		.await
+		.with_context(|| format!("failed to GET JWKS from {}", jwks_uri))?
+		.error_for_status()
+		.with_context(|| format!("JWKS endpoint returned error: {}", jwks_uri))?
+ .text() + .await + .context("failed to read JWKS response body")?; + + // Parse the JWKS into a KeySet + KeySet::from_str(&jwks_json).context("Failed to parse JWKS into KeySet") +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Algorithm; + use std::time::{Duration, SystemTime}; + + fn create_test_claims() -> Claims { + Claims { + root: "test-path".to_string(), + publish: vec!["test-pub".into()], + cluster: false, + subscribe: vec!["test-sub".into()], + expires: Some(SystemTime::now() + Duration::from_secs(3600)), + issued: Some(SystemTime::now()), + } + } + + fn create_test_key(kid: Option) -> Key { + Key::generate(Algorithm::ES256, kid).expect("failed to generate key") + } + + #[test] + fn test_keyset_from_str_valid() { + let json = r#"{"keys":[{"kty":"oct","k":"2AJvfDJMVfWe9WMRPJP-4zCGN8F62LOy3dUr--rogR8","alg":"HS256","key_ops":["verify","sign"],"kid":"1"}]}"#; + let set = KeySet::from_str(json); + assert!(set.is_ok()); + let set = set.unwrap(); + assert_eq!(set.keys.len(), 1); + assert_eq!(set.keys[0].kid.as_deref(), Some("1")); + assert!(set.find_key("1").is_some()); + } + + #[test] + fn test_keyset_from_str_invalid_json() { + let result = KeySet::from_str("invalid json"); + assert!(result.is_err()); + } + + #[test] + fn test_keyset_from_str_empty() { + let json = r#"{"keys":[]}"#; + let set = KeySet::from_str(json).unwrap(); + assert!(set.keys.is_empty()); + } + + #[test] + fn test_keyset_to_str() { + let key = create_test_key(Some("1".to_string())); + let set = KeySet { + keys: vec![Arc::new(key)], + }; + + let json = set.to_str().unwrap(); + assert!(json.contains("\"keys\"")); + assert!(json.contains("\"kid\":\"1\"")); + } + + #[test] + fn test_keyset_serde_round_trip() { + let key1 = create_test_key(Some("1".to_string())); + let key2 = create_test_key(Some("2".to_string())); + let set = KeySet { + keys: vec![Arc::new(key1), Arc::new(key2)], + }; + + let json = set.to_str().unwrap(); + let deserialized = KeySet::from_str(&json).unwrap(); + + assert_eq!(deserialized.keys.len(), 2); + assert!(deserialized.find_key("1").is_some()); + assert!(deserialized.find_key("2").is_some()); + } + + #[test] + fn test_find_key_success() { + let key = create_test_key(Some("my-key".to_string())); + let set = KeySet { + keys: vec![Arc::new(key)], + }; + + let found = set.find_key("my-key"); + assert!(found.is_some()); + assert_eq!(found.unwrap().kid.as_deref(), Some("my-key")); + } + + #[test] + fn test_find_key_missing() { + let key = create_test_key(Some("my-key".to_string())); + let set = KeySet { + keys: vec![Arc::new(key)], + }; + + let found = set.find_key("other-key"); + assert!(found.is_none()); + } + + #[test] + fn test_find_key_no_kid() { + let key = create_test_key(None); + let set = KeySet { + keys: vec![Arc::new(key)], + }; + + let found = set.find_key("any-key"); + assert!(found.is_none()); + } + + #[test] + fn test_find_supported_key() { + let mut sign_key = create_test_key(Some("sign".to_string())); + sign_key.operations = [KeyOperation::Sign].into(); + + let mut verify_key = create_test_key(Some("verify".to_string())); + verify_key.operations = [KeyOperation::Verify].into(); + + let set = KeySet { + keys: vec![Arc::new(sign_key), Arc::new(verify_key)], + }; + + let found_sign = set.find_supported_key(&KeyOperation::Sign); + assert!(found_sign.is_some()); + assert_eq!(found_sign.unwrap().kid.as_deref(), Some("sign")); + + let found_verify = set.find_supported_key(&KeyOperation::Verify); + assert!(found_verify.is_some()); + assert_eq!(found_verify.unwrap().kid.as_deref(), 
Some("verify")); + } + + #[test] + fn test_to_public_set() { + // Use asymmetric key (ES256) so we can separate public/private + let key = create_test_key(Some("1".to_string())); + + let set = KeySet { + keys: vec![Arc::new(key)], + }; + + let public_set = set.to_public_set().expect("failed to convert to public set"); + assert_eq!(public_set.keys.len(), 1); + + let public_key = &public_set.keys[0]; + assert_eq!(public_key.kid.as_deref(), Some("1")); + assert!(public_key.operations.contains(&KeyOperation::Verify)); + assert!(!public_key.operations.contains(&KeyOperation::Sign)); + } + + #[test] + fn test_to_public_set_fails_for_symmetric() { + let key = Key::generate(Algorithm::HS256, Some("sym".to_string())).unwrap(); + let set = KeySet { + keys: vec![Arc::new(key)], + }; + + let result = set.to_public_set(); + assert!(result.is_err()); + } + + #[test] + fn test_encode_success() { + let key = create_test_key(Some("1".to_string())); + let set = KeySet { + keys: vec![Arc::new(key)], + }; + let claims = create_test_claims(); + + let token = set.encode(&claims).unwrap(); + assert!(!token.is_empty()); + } + + #[test] + fn test_encode_no_signing_key() { + let mut key = create_test_key(Some("1".to_string())); + key.operations = [KeyOperation::Verify].into(); + let set = KeySet { + keys: vec![Arc::new(key)], + }; + let claims = create_test_claims(); + + let result = set.encode(&claims); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("cannot find signing key")); + } + + #[test] + fn test_decode_success_with_kid() { + let key = create_test_key(Some("1".to_string())); + let set = KeySet { + keys: vec![Arc::new(key)], + }; + let claims = create_test_claims(); + + let token = set.encode(&claims).unwrap(); + let decoded = set.decode(&token).unwrap(); + + assert_eq!(decoded.root, claims.root); + } + + #[test] + fn test_decode_success_single_key_no_kid() { + // Create a key without KID + let key = create_test_key(None); + let claims = create_test_claims(); + + // Encode using the key directly + let token = key.encode(&claims).unwrap(); + + let set = KeySet { + keys: vec![Arc::new(key)], + }; + + // Decode using the set + let decoded = set.decode(&token).unwrap(); + assert_eq!(decoded.root, claims.root); + } + + #[test] + fn test_decode_fail_multiple_keys_no_kid() { + let key1 = create_test_key(None); + let key2 = create_test_key(None); + + let set = KeySet { + keys: vec![Arc::new(key1), Arc::new(key2)], + }; + + let claims = create_test_claims(); + // Encode with one of the keys directly + let token = set.keys[0].encode(&claims).unwrap(); + + let result = set.decode(&token); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("missing kid")); + } + + #[test] + fn test_decode_fail_unknown_kid() { + let key1 = create_test_key(Some("1".to_string())); + let key2 = create_test_key(Some("2".to_string())); + + let set1 = KeySet { + keys: vec![Arc::new(key1)], + }; + let set2 = KeySet { + keys: vec![Arc::new(key2)], + }; + + let claims = create_test_claims(); + let token = set1.encode(&claims).unwrap(); + + let result = set2.decode(&token); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("cannot find key with kid 1")); + } + + #[test] + fn test_file_io() { + let key = create_test_key(Some("1".to_string())); + let set = KeySet { + keys: vec![Arc::new(key)], + }; + + let dir = std::env::temp_dir(); + // Use a random-ish name to avoid collisions + let filename = format!( + "test_keyset_{}.json", + SystemTime::now() + 
.duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_nanos() + ); + let path = dir.join(filename); + + set.to_file(&path).expect("failed to write to file"); + + let loaded = KeySet::from_file(&path).expect("failed to read from file"); + assert_eq!(loaded.keys.len(), 1); + assert_eq!(loaded.keys[0].kid.as_deref(), Some("1")); + + // Clean up + let _ = std::fs::remove_file(path); + } +}
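Reviewer note (not part of the diff): a minimal sketch of how the new `KeySet` API ties the guide's steps together. One side strips private material with `to_public_set` and writes a `keys.json` to host over HTTPS, while a consumer built with the `jwks-loader` feature fetches the set with `load_keys` and validates a token with `decode`. It assumes a tokio runtime and `anyhow`; the file names, URL, and `TOKEN` environment variable are placeholders.

```rust
use std::sync::Arc;

use moq_token::{Key, KeySet};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Publisher side: strip private material from a locally generated key and
    // write the resulting public JWK set, ready to be served over HTTPS.
    let private = Key::from_file("private.jwk")?;
    let public_set = KeySet { keys: vec![Arc::new(private)] }.to_public_set()?;
    public_set.to_file("keys.json")?;

    // Relay/consumer side (requires the `jwks-loader` feature): fetch the set
    // and use it to validate an incoming token; the token comes from the
    // TOKEN environment variable purely for demonstration.
    let keys = moq_token::load_keys("https://auth.example.com/keys.json").await?;
    let token = std::env::var("TOKEN")?;
    let claims = keys.decode(&token)?;
    println!("token is valid for root {}", claims.root);

    Ok(())
}
```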