diff --git a/crates/sprout-db/src/channel.rs b/crates/sprout-db/src/channel.rs index 46b2cb5e1..dbf8f0310 100644 --- a/crates/sprout-db/src/channel.rs +++ b/crates/sprout-db/src/channel.rs @@ -522,6 +522,32 @@ pub async fn get_members(pool: &PgPool, channel_id: Uuid) -> Result` ordered by `joined_at`; callers should +/// group by `channel_id` if per-channel access is needed. +/// Returns an empty vec immediately when `channel_ids` is empty. +pub async fn get_members_bulk(pool: &PgPool, channel_ids: &[Uuid]) -> Result> { + if channel_ids.is_empty() { + return Ok(Vec::new()); + } + let rows = sqlx::query( + r#" + SELECT cm.channel_id, cm.pubkey, cm.role::text AS role, cm.joined_at, cm.invited_by, cm.removed_at + FROM channel_members cm + JOIN channels c ON cm.channel_id = c.id AND c.deleted_at IS NULL + WHERE cm.channel_id = ANY($1) AND cm.removed_at IS NULL + ORDER BY cm.joined_at ASC + "#, + ) + .bind(channel_ids) + .fetch_all(pool) + .await?; + rows.into_iter().map(row_to_member_record).collect() +} + /// Get all channel IDs accessible to a pubkey. /// /// Includes channels where the pubkey is an active member AND all open channels. diff --git a/crates/sprout-db/src/lib.rs b/crates/sprout-db/src/lib.rs index 4063cf8a1..b9f1326a0 100644 --- a/crates/sprout-db/src/lib.rs +++ b/crates/sprout-db/src/lib.rs @@ -141,11 +141,14 @@ pub struct DbConfig { } impl Default for DbConfig { + /// Sized for a single relay pod against PG max_connections=100. + /// Staging measured 51 idle + 1 active out of 50 — most connections sat unused. + /// At 20 main + 5 audit = 25/pod, four relay pods fit within the PG limit. 
fn default() -> Self { Self { database_url: "postgres://sprout:sprout_dev@localhost:5432/sprout".to_string(), - max_connections: 50, - min_connections: 5, + max_connections: 20, + min_connections: 2, acquire_timeout_secs: 3, max_lifetime_secs: 1800, idle_timeout_secs: 600, @@ -398,6 +401,14 @@ impl Db { channel::get_members(&self.pool, channel_id).await } + /// Returns active members for multiple channels in a single query. + pub async fn get_members_bulk( + &self, + channel_ids: &[Uuid], + ) -> Result> { + channel::get_members_bulk(&self.pool, channel_ids).await + } + /// Get all channel IDs accessible to a pubkey. pub async fn get_accessible_channel_ids(&self, pubkey: &[u8]) -> Result> { channel::get_accessible_channel_ids(&self.pool, pubkey).await diff --git a/crates/sprout-relay/src/api/channels.rs b/crates/sprout-relay/src/api/channels.rs index 1aceef693..2047c20ef 100644 --- a/crates/sprout-relay/src/api/channels.rs +++ b/crates/sprout-relay/src/api/channels.rs @@ -64,12 +64,77 @@ pub async fn channels_handler( .await .unwrap_or_default(); + // ── Batch DM participant resolution (2 queries total, not 2×N_DMs) ── + let dm_channel_ids: Vec = channels + .iter() + .filter(|ac| ac.channel.channel_type == "dm") + .map(|ac| ac.channel.id) + .collect(); + + // 1. One query: all members for all DM channels. + let all_dm_members = state + .db + .get_members_bulk(&dm_channel_ids) + .await + .unwrap_or_else(|e| { + tracing::error!("channels: failed to bulk-load DM members: {e}"); + vec![] + }); + + // 2. Collect unique pubkeys across all DM members. + let unique_pubkeys: Vec> = { + let mut seen = std::collections::HashSet::new(); + all_dm_members + .iter() + .filter(|m| seen.insert(m.pubkey.clone())) + .map(|m| m.pubkey.clone()) + .collect() + }; + + // 3. One query: resolve display names for all unique pubkeys. 
+ let user_records = state + .db + .get_users_bulk(&unique_pubkeys) + .await + .unwrap_or_else(|e| { + tracing::error!("channels: failed to bulk-load DM participant profiles: {e}"); + vec![] + }); + let user_map: HashMap = user_records + .into_iter() + .filter_map(|u| { + let hex = nostr_hex::encode(&u.pubkey); + u.display_name.map(|name| (hex, name)) + }) + .collect(); + + // 4. Group members by channel_id for O(1) lookup. + let mut members_by_channel: HashMap> = + HashMap::new(); + for m in &all_dm_members { + members_by_channel.entry(m.channel_id).or_default().push(m); + } + let mut result = Vec::with_capacity(channels.len()); for ac in &channels { let ch = &ac.channel; let (participants, participant_pubkeys) = if ch.channel_type == "dm" { - resolve_dm_participants(&state, ch.id).await + let members = members_by_channel.get(&ch.id); + let mut names = Vec::new(); + let mut pk_hexes = Vec::new(); + if let Some(members) = members { + for m in members { + let hex = nostr_hex::encode(&m.pubkey); + let name = user_map + .get(&hex) + .cloned() + .unwrap_or_else(|| hex[..8.min(hex.len())].to_string()); + names.push(name); + pk_hexes.push(hex); + } + } + (names, pk_hexes) } else { (vec![], vec![]) }; @@ -119,46 +184,3 @@ fn channel_record_to_json( "ttl_deadline": channel.ttl_deadline.map(|t| t.to_rfc3339()), }) } - -/// Fetch DM participants and resolve their display names. 
-async fn resolve_dm_participants( - state: &AppState, - channel_id: uuid::Uuid, -) -> (Vec, Vec) { - let members = state.db.get_members(channel_id).await.unwrap_or_else(|e| { - tracing::error!("channels: failed to load members for channel {channel_id}: {e}"); - vec![] - }); - - let member_pubkeys: Vec> = members.iter().map(|m| m.pubkey.clone()).collect(); - - let user_records = state - .db - .get_users_bulk(&member_pubkeys) - .await - .unwrap_or_else(|e| { - tracing::error!("channels: failed to load user records for DM participants: {e}"); - vec![] - }); - - let user_map: HashMap = user_records - .into_iter() - .filter_map(|u| { - let hex = nostr_hex::encode(&u.pubkey); - u.display_name.map(|name| (hex, name)) - }) - .collect(); - - let mut names = Vec::new(); - let mut pk_hexes = Vec::new(); - for m in &members { - let hex = nostr_hex::encode(&m.pubkey); - let name = user_map - .get(&hex) - .cloned() - .unwrap_or_else(|| hex[..8.min(hex.len())].to_string()); - names.push(name); - pk_hexes.push(hex); - } - (names, pk_hexes) -} diff --git a/crates/sprout-relay/src/api/dms.rs b/crates/sprout-relay/src/api/dms.rs index 66414160e..b1ab18191 100644 --- a/crates/sprout-relay/src/api/dms.rs +++ b/crates/sprout-relay/src/api/dms.rs @@ -116,6 +116,14 @@ pub async fn open_dm_handler( .map_err(|e| internal_error(&format!("db error: {e}")))?; if was_created { + // Invalidate membership + accessible-channels caches for all participants + // so REQ, /api/feed, and /api/search immediately include the new DM. + // Note: DM hide/unhide does NOT need cache invalidation because + // get_accessible_channel_ids() does not filter on hidden_at. + for pk in &all_bytes { + state.invalidate_membership(channel.id, pk); + } + let actor_hex = nostr_hex::encode(&self_bytes); let participant_hexes: Vec = all_bytes.iter().map(nostr_hex::encode).collect(); if let Err(e) = emit_system_message( @@ -199,8 +207,7 @@ pub async fn add_dm_member_handler( // Verify caller is a member of the existing DM. 
let is_member = state - .db - .is_member(channel_id, &self_bytes) + .is_member_cached(channel_id, &self_bytes) .await .map_err(|e| internal_error(&format!("db error: {e}")))?; if !is_member { @@ -262,6 +269,12 @@ pub async fn add_dm_member_handler( .map_err(|e| internal_error(&format!("db error: {e}")))?; if was_created { + // Invalidate membership + accessible-channels caches for all participants + // so REQ, /api/feed, and /api/search immediately include the new DM. + for pk in &all_bytes { + state.invalidate_membership(new_channel.id, pk); + } + // Emit NIP-29 group discovery events for the new expanded DM. if let Err(e) = emit_group_discovery_events(&state, new_channel.id).await { tracing::warn!(channel = %new_channel.id, "DM discovery emission failed: {e}"); @@ -393,8 +406,7 @@ pub async fn hide_dm_handler( // Verify caller is a member. let is_member = state - .db - .is_member(channel_id, &ctx.pubkey_bytes) + .is_member_cached(channel_id, &ctx.pubkey_bytes) .await .map_err(|e| internal_error(&format!("db error: {e}")))?; diff --git a/crates/sprout-relay/src/api/feed.rs b/crates/sprout-relay/src/api/feed.rs index df7d26ecb..633f7aa6d 100644 --- a/crates/sprout-relay/src/api/feed.rs +++ b/crates/sprout-relay/src/api/feed.rs @@ -73,8 +73,7 @@ pub async fn feed_handler( let accessible_ids = constrain_channel_ids( state - .db - .get_accessible_channel_ids(&pubkey_bytes) + .get_accessible_channel_ids_cached(&pubkey_bytes) .await .map_err(|e| internal_error(&format!("db error: {e}")))?, ctx.channel_ids.as_deref(), diff --git a/crates/sprout-relay/src/api/media.rs b/crates/sprout-relay/src/api/media.rs index 8dc83c712..20ecf02bd 100644 --- a/crates/sprout-relay/src/api/media.rs +++ b/crates/sprout-relay/src/api/media.rs @@ -159,26 +159,28 @@ pub async fn upload_blob( }; metrics::counter!("sprout_media_uploads_total", "mime" => mime_label.to_owned()).increment(1); - // Fire-and-forget audit — never block the response on audit I/O. 
- let audit = state.audit.clone(); + // Audit via bounded channel — same pattern as event audit. let desc = descriptor.clone(); let uploader = auth.auth_event.pubkey.to_hex(); - tokio::spawn(async move { - let _ = audit - .log(NewAuditEntry { - event_id: desc.sha256.clone(), - event_kind: sprout_core::kind::KIND_MEDIA_UPLOAD, - actor_pubkey: uploader, - action: AuditAction::MediaUploaded, - channel_id: None, - metadata: serde_json::json!({ - "sha256": desc.sha256, - "size": desc.size, - "mime": desc.mime_type, - }), - }) - .await; - }); + if let Err(e) = state + .audit_tx + .send(NewAuditEntry { + event_id: desc.sha256.clone(), + event_kind: sprout_core::kind::KIND_MEDIA_UPLOAD, + actor_pubkey: uploader, + action: AuditAction::MediaUploaded, + channel_id: None, + metadata: serde_json::json!({ + "sha256": desc.sha256, + "size": desc.size, + "mime": desc.mime_type, + }), + }) + .await + { + tracing::error!("Media audit channel closed — entry lost: {e}"); + metrics::counter!("sprout_audit_send_errors_total").increment(1); + } Ok(Json(descriptor)) } diff --git a/crates/sprout-relay/src/api/mod.rs b/crates/sprout-relay/src/api/mod.rs index 2081d0cd0..3167d06a8 100644 --- a/crates/sprout-relay/src/api/mod.rs +++ b/crates/sprout-relay/src/api/mod.rs @@ -520,8 +520,7 @@ pub(crate) async fn check_channel_membership( pubkey_bytes: &[u8], ) -> Result<(), (StatusCode, Json)> { let is_member = state - .db - .is_member(channel_id, pubkey_bytes) + .is_member_cached(channel_id, pubkey_bytes) .await .map_err(|e| internal_error(&format!("db error: {e}")))?; if is_member { diff --git a/crates/sprout-relay/src/api/search.rs b/crates/sprout-relay/src/api/search.rs index e066cd04a..34d5006c3 100644 --- a/crates/sprout-relay/src/api/search.rs +++ b/crates/sprout-relay/src/api/search.rs @@ -44,8 +44,7 @@ pub async fn search_handler( let channel_ids = constrain_channel_ids( state - .db - .get_accessible_channel_ids(&pubkey_bytes) + .get_accessible_channel_ids_cached(&pubkey_bytes) .await 
.unwrap_or_default(), ctx.channel_ids.as_deref(), diff --git a/crates/sprout-relay/src/api/tokens.rs b/crates/sprout-relay/src/api/tokens.rs index 6cd435b67..275017d33 100644 --- a/crates/sprout-relay/src/api/tokens.rs +++ b/crates/sprout-relay/src/api/tokens.rs @@ -473,8 +473,7 @@ pub async fn post_tokens( // Verify caller is a member of the channel. let is_member = state - .db - .is_member(cid, &ctx.pubkey_bytes) + .is_member_cached(cid, &ctx.pubkey_bytes) .await .map_err(|e| internal_error(&format!("db error: {e}")))?; if !is_member { diff --git a/crates/sprout-relay/src/audio/handler.rs b/crates/sprout-relay/src/audio/handler.rs index 055df4b78..33547bf11 100644 --- a/crates/sprout-relay/src/audio/handler.rs +++ b/crates/sprout-relay/src/audio/handler.rs @@ -512,8 +512,7 @@ async fn ensure_membership( // Fast path: already a member. let is_member = state - .db - .is_member(channel_id, pubkey_bytes) + .is_member_cached(channel_id, pubkey_bytes) .await .map_err(|e| format!("db error: {e}"))?; @@ -537,8 +536,7 @@ async fn ensure_membership( if channel.ttl_seconds.is_some() { if let Some(parent_id) = parent_channel_id { let parent_member = state - .db - .is_member(parent_id, pubkey_bytes) + .is_member_cached(parent_id, pubkey_bytes) .await .map_err(|e| format!("db error: {e}"))?; @@ -553,6 +551,7 @@ async fn ensure_membership( ) .await .map_err(|e| format!("auto-add failed: {e}"))?; + state.invalidate_membership(channel_id, pubkey_bytes); return Ok(()); } diff --git a/crates/sprout-relay/src/handlers/event.rs b/crates/sprout-relay/src/handlers/event.rs index d7ddd91a4..7725269d6 100644 --- a/crates/sprout-relay/src/handlers/event.rs +++ b/crates/sprout-relay/src/handlers/event.rs @@ -61,6 +61,7 @@ pub(crate) async fn dispatch_persistent_event( } let matches = state.sub_registry.fan_out(stored_event); + metrics::histogram!("sprout_fanout_recipients").record(matches.len() as f64); debug!( event_id = %event_id_hex, channel_id = ?stored_event.channel_id, @@ -96,26 
+97,25 @@ pub(crate) async fn dispatch_persistent_event( warn!(event_id = %event_id_hex, "Search index channel full — dropping event"); } - let audit = Arc::clone(&state.audit); - let audit_event_id = event_id_hex.clone(); - let audit_actor_pubkey = actor_pubkey_hex.to_string(); - let audit_channel_id = stored_event.channel_id; - tokio::spawn(async move { - let entry = sprout_audit::NewAuditEntry { - event_id: audit_event_id.clone(), - event_kind: kind_u32, - actor_pubkey: audit_actor_pubkey, - action: sprout_audit::AuditAction::EventCreated, - channel_id: audit_channel_id, - metadata: serde_json::Value::Null, - }; - let t = std::time::Instant::now(); - if let Err(e) = audit.log(entry).await { - error!(event_id = %audit_event_id, "Audit log failed: {e}"); - } else { - metrics::histogram!("sprout_audit_log_seconds").record(t.elapsed().as_secs_f64()); - } - }); + // Audit via bounded channel (capacity 1000). Uses .send().await so entries + // are never silently dropped — backpressure propagates to the event handler + // if the queue is full. This is intentional: the audit advisory lock already + // serializes writes (at most 1 in-flight), so a full queue means the audit + // DB is genuinely overloaded and the relay should slow down rather than + // accumulate unbounded in-memory state. DB write failures in the worker are + // logged but not retried (same as the previous per-event tokio::spawn). 
+ let audit_entry = sprout_audit::NewAuditEntry { + event_id: event_id_hex.clone(), + event_kind: kind_u32, + actor_pubkey: actor_pubkey_hex.to_string(), + action: sprout_audit::AuditAction::EventCreated, + channel_id: stored_event.channel_id, + metadata: serde_json::Value::Null, + }; + if let Err(e) = state.audit_tx.send(audit_entry).await { + error!(event_id = %event_id_hex, "Audit channel closed — entry lost: {e}"); + metrics::counter!("sprout_audit_send_errors_total").increment(1); + } // Skip workflow triggering for workflow-execution kinds and relay-signed workflow messages. let is_relay_workflow_msg = stored_event.event.pubkey == state.relay_keypair.public_key() @@ -332,6 +332,7 @@ async fn handle_ephemeral_event( let stored_event = StoredEvent::new(event.clone(), None); let matches = state.sub_registry.fan_out(&stored_event); + metrics::histogram!("sprout_fanout_recipients").record(matches.len() as f64); let event_json = serde_json::to_string(&event) .expect("nostr::Event serialization is infallible for well-formed events"); let mut drop_count = 0u32; @@ -375,6 +376,7 @@ async fn handle_ephemeral_event( // Pass the channel_id so fan_out() uses the channel-kind index. let stored_event = StoredEvent::new(event.clone(), Some(ch_id)); let matches = state.sub_registry.fan_out(&stored_event); + metrics::histogram!("sprout_fanout_recipients").record(matches.len() as f64); let event_json = serde_json::to_string(&event) .expect("nostr::Event serialization is infallible for well-formed events"); let mut drop_count = 0u32; @@ -411,6 +413,7 @@ async fn handle_ephemeral_event( // Pass channel_id=None so fan_out() uses the global subscriber index. 
let stored_event = StoredEvent::new(event.clone(), None); let matches = state.sub_registry.fan_out(&stored_event); + metrics::histogram!("sprout_fanout_recipients").record(matches.len() as f64); let event_json = serde_json::to_string(&event) .expect("nostr::Event serialization is infallible for well-formed events"); let mut drop_count = 0u32; diff --git a/crates/sprout-relay/src/handlers/ingest.rs b/crates/sprout-relay/src/handlers/ingest.rs index d6ddfb9fa..60c5f8c83 100644 --- a/crates/sprout-relay/src/handlers/ingest.rs +++ b/crates/sprout-relay/src/handlers/ingest.rs @@ -316,7 +316,7 @@ pub(crate) async fn check_channel_membership( ch_id: Uuid, pubkey_bytes: &[u8], ) -> Result<(), String> { - match state.db.is_member(ch_id, pubkey_bytes).await { + match state.is_member_cached(ch_id, pubkey_bytes).await { Ok(true) => return Ok(()), Ok(false) => {} Err(e) => return Err(format!("error: database error: {e}")), @@ -1310,6 +1310,7 @@ pub async fn ingest_event( if let Err(re) = state.db.soft_delete_channel(ch_id).await { warn!(event_id = %event_id_hex, "channel compensation failed: {re}"); } + state.invalidate_channel_deleted(); } return Err(match e { sprout_db::DbError::AuthEventRejected => { diff --git a/crates/sprout-relay/src/handlers/req.rs b/crates/sprout-relay/src/handlers/req.rs index c9c12e257..f6ba2a0a6 100644 --- a/crates/sprout-relay/src/handlers/req.rs +++ b/crates/sprout-relay/src/handlers/req.rs @@ -68,7 +68,8 @@ pub async fn handle_req( } }; - let mut accessible_channels = match state.db.get_accessible_channel_ids(&pubkey_bytes).await { + let mut accessible_channels = match state.get_accessible_channel_ids_cached(&pubkey_bytes).await + { Ok(ids) => ids, Err(e) => { warn!(conn_id = %conn_id, "Failed to get accessible channels: {e}"); diff --git a/crates/sprout-relay/src/handlers/side_effects.rs b/crates/sprout-relay/src/handlers/side_effects.rs index 24efec886..39a09932b 100644 --- a/crates/sprout-relay/src/handlers/side_effects.rs +++ 
b/crates/sprout-relay/src/handlers/side_effects.rs @@ -301,7 +301,7 @@ pub async fn validate_admin_event( } } else { // topic/purpose: any member - let is_member = state.db.is_member(channel_id, &actor_bytes).await?; + let is_member = state.is_member_cached(channel_id, &actor_bytes).await?; if is_member { Ok(()) } else { @@ -698,6 +698,7 @@ async fn handle_put_user(event: &Event, state: &Arc) -> anyhow::Result .db .add_member(channel_id, &target_pubkey, role, Some(&actor_bytes)) .await?; + state.invalidate_membership(channel_id, &target_pubkey); let actor_hex = nostr::util::hex::encode(&actor_bytes); let target_hex = nostr::util::hex::encode(&target_pubkey); @@ -756,6 +757,7 @@ async fn handle_remove_user(event: &Event, state: &Arc) -> anyhow::Res .db .remove_member(channel_id, &target_pubkey, &actor_bytes) .await?; + state.invalidate_membership(channel_id, &target_pubkey); evict_live_channel_subscriptions(state, channel_id, &target_pubkey).await; let actor_hex = nostr::util::hex::encode(&actor_bytes); @@ -1033,6 +1035,14 @@ async fn handle_create_group(event: &Event, state: &Arc) -> anyhow::Re .await? }; + // Creator becomes owner — evict any stale negative membership lookup. + state.invalidate_membership(channel.id, &actor_bytes); + // Open channels appear in everyone's accessible set; private channels only + // affect the creator (the sole initial member). + if visibility == sprout_db::channel::ChannelVisibility::Open { + state.invalidate_all_accessible_channels(); + } + let actor_hex = nostr::util::hex::encode(&actor_bytes); emit_system_message( state, @@ -1088,6 +1098,10 @@ async fn handle_delete_group(event: &Event, state: &Arc) -> anyhow::Re warn!(channel = %channel_id, error = %e, "failed to clean up NIP-29 discovery events"); } + // Deleted channel: clear both membership and accessible-channels caches. + // Stale is_member=true entries would bypass the DB's deleted_at guard. 
+ state.invalidate_channel_deleted(); + let actor_hex = nostr::util::hex::encode(&actor_bytes); emit_system_message( state, @@ -1121,7 +1135,7 @@ async fn handle_join_request(event: &Event, state: &Arc) -> anyhow::Re // Skip if already an active member — prevents duplicate join notifications. // Fail closed on DB errors rather than falling through to add_member. - if state.db.is_member(channel_id, &actor_bytes).await? { + if state.is_member_cached(channel_id, &actor_bytes).await? { info!(channel = %channel_id, "kind:9021 join — already a member, skipping"); return Ok(()); } @@ -1136,6 +1150,7 @@ async fn handle_join_request(event: &Event, state: &Arc) -> anyhow::Re None, ) .await?; + state.invalidate_membership(channel_id, &actor_bytes); let actor_hex = nostr::util::hex::encode(&actor_bytes); emit_system_message( @@ -1191,6 +1206,7 @@ async fn handle_leave_request(event: &Event, state: &Arc) -> anyhow::R .db .remove_member(channel_id, &actor_bytes, &actor_bytes) .await?; + state.invalidate_membership(channel_id, &actor_bytes); evict_live_channel_subscriptions(state, channel_id, &actor_bytes).await; let actor_hex = nostr::util::hex::encode(&actor_bytes); diff --git a/crates/sprout-relay/src/main.rs b/crates/sprout-relay/src/main.rs index 64c1ac2cf..0ecbfc954 100644 --- a/crates/sprout-relay/src/main.rs +++ b/crates/sprout-relay/src/main.rs @@ -67,7 +67,10 @@ async fn main() -> anyhow::Result<()> { Err(e) => error!("Failed to backfill d_tags: {e}"), } - let audit_pool = sqlx::PgPool::connect(&config.database_url) + let audit_pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(5) + .min_connections(1) + .connect(&config.database_url) .await .map_err(|e| anyhow::anyhow!("Audit DB connection failed: {e}"))?; let audit = AuditService::new(audit_pool); @@ -127,7 +130,7 @@ async fn main() -> anyhow::Result<()> { .map_err(|e| anyhow::anyhow!("failed to initialize media storage: {e}"))?; info!("Media storage connected"); - let state = Arc::new(AppState::new( + let 
(app_state, audit_shutdown) = AppState::new( config.clone(), db, redis_health_pool, @@ -138,7 +141,8 @@ async fn main() -> anyhow::Result<()> { Arc::clone(&workflow_engine), relay_keypair, media_storage, - )); + ); + let state = Arc::new(app_state); // Wire the action sink — must happen after AppState (which creates // sub_registry, conn_manager) and before the cron loop starts. @@ -287,7 +291,17 @@ async fn main() -> anyhow::Result<()> { let router = build_router(Arc::clone(&state)); let health_router = build_health_router(Arc::clone(&state)); - serve(router, health_router, Arc::clone(&state)).await + serve(router, health_router, Arc::clone(&state)).await?; + + // ── Drain audit queue ──────────────────────────────────────────────────── + // Signal the audit worker to stop accepting, flush buffered entries, and + // exit. Uses a CancellationToken so it works regardless of how many + // Arc clones are still alive in background tasks. + audit_shutdown + .drain(std::time::Duration::from_secs(5)) + .await; + + Ok(()) } /// Bind all listeners and run with graceful shutdown. diff --git a/crates/sprout-relay/src/state.rs b/crates/sprout-relay/src/state.rs index 9a78bf09a..57836e065 100644 --- a/crates/sprout-relay/src/state.rs +++ b/crates/sprout-relay/src/state.rs @@ -6,7 +6,9 @@ use std::time::Instant; use axum::extract::ws::Message as WsMessage; use dashmap::DashMap; -use tokio::sync::{mpsc, Semaphore}; +use tokio::sync::mpsc; +use tokio::sync::Semaphore; +use tokio::task::JoinHandle; use tokio_util::sync::CancellationToken; use uuid::Uuid; @@ -196,11 +198,19 @@ pub struct AppState { pub local_event_ids: Arc>, /// Membership cache: (channel_id, pubkey_bytes) → is_member. /// Short TTL (10s) — membership changes are rare but must propagate. + /// Multi-pod: other pods rely on TTL expiry; only local caches are invalidated. pub membership_cache: Arc), bool>>, + /// Accessible channel IDs cache: pubkey_bytes → channel UUIDs. 
+ /// Short TTL (10s) — invalidated on membership or channel visibility changes. + /// Multi-pod: other pods rely on TTL expiry; only local caches are invalidated. + pub accessible_channels_cache: Arc, Vec>>, /// Bounded channel for search indexing — prevents OOM if Typesense is slow/down. /// Capacity 1000: at ~1KB/event that's ~1MB of backlog before we start dropping. pub search_index_tx: mpsc::Sender, + /// Bounded channel for audit logging — backpressure instead of unbounded spawns. + /// Uses .send().await (blocks caller if full) because audit entries must not be lost. + pub audit_tx: mpsc::Sender, /// Media storage client (S3/MinIO). pub media_storage: Arc, /// Audio relay room manager — tracks active huddle audio rooms. @@ -213,6 +223,10 @@ pub struct AppState { impl AppState { /// Constructs `AppState` from its component services. + /// + /// Returns `(state, audit_shutdown)`. The caller should call + /// `audit_shutdown.drain(timeout).await` during graceful shutdown so queued + /// audit entries are flushed before the process exits. #[allow(clippy::too_many_arguments)] pub fn new( config: Config, @@ -225,7 +239,7 @@ impl AppState { workflow_engine: Arc, relay_keypair: nostr::Keys, media_storage: MediaStorage, - ) -> Self { + ) -> (Self, AuditShutdownHandle) { let max_connections = config.max_connections; let max_concurrent_handlers = config.max_concurrent_handlers; let search_arc = Arc::new(search); @@ -252,11 +266,46 @@ impl AppState { tracing::warn!("search index worker exited (expected on shutdown)"); }); - Self { + let audit_arc = Arc::new(audit); + let (audit_tx, mut audit_rx) = mpsc::channel::(1000); + let audit_for_worker = Arc::clone(&audit_arc); + let audit_cancel = CancellationToken::new(); + let audit_cancel_worker = audit_cancel.clone(); + let audit_worker_handle = tokio::spawn(async move { + // Normal operation: process entries as they arrive. + loop { + tokio::select! 
{ + entry = audit_rx.recv() => { + match entry { + Some(entry) => log_audit_entry(&audit_for_worker, entry).await, + None => break, // channel closed + } + } + _ = audit_cancel_worker.cancelled() => { + // Close the receiver: rejects future sends and lets us + // drain everything already buffered without a race. + audit_rx.close(); + break; + } + } + } + // Drain: recv() returns buffered entries, then None once empty. + let mut drained = 0u32; + while let Some(entry) = audit_rx.recv().await { + log_audit_entry(&audit_for_worker, entry).await; + drained += 1; + } + if drained > 0 { + tracing::info!(drained, "audit worker flushed remaining entries"); + } + tracing::warn!("audit log worker exited (expected on shutdown)"); + }); + + let state = Self { config: Arc::new(config), db, redis_pool, - audit: Arc::new(audit), + audit: audit_arc, pubsub, auth: Arc::new(auth), search: search_arc, @@ -280,13 +329,27 @@ impl AppState { .time_to_live(std::time::Duration::from_secs(10)) .build(), ), + accessible_channels_cache: Arc::new( + moka::sync::Cache::builder() + .max_capacity(10_000) + .time_to_live(std::time::Duration::from_secs(10)) + .build(), + ), search_index_tx, + audit_tx, media_storage: Arc::new(media_storage), audio_rooms: Arc::new(AudioRoomManager::new()), shutting_down: Arc::new(AtomicBool::new(false)), started_at: Instant::now(), - } + }; + ( + state, + AuditShutdownHandle { + cancel: audit_cancel, + handle: audit_worker_handle, + }, + ) } /// Record an event ID as locally-published for dedup. @@ -294,6 +357,99 @@ impl AppState { pub fn mark_local_event(&self, event_id: &nostr::EventId) { self.local_event_ids.insert(event_id.to_bytes(), ()); } + + /// Check channel membership with a 10-second cache. Falls back to DB on miss. 
+ pub async fn is_member_cached( + &self, + channel_id: Uuid, + pubkey: &[u8], + ) -> Result { + let key = (channel_id, pubkey.to_vec()); + if let Some(cached) = self.membership_cache.get(&key) { + metrics::counter!("sprout_membership_cache_hits_total").increment(1); + return Ok(cached); + } + metrics::counter!("sprout_membership_cache_misses_total").increment(1); + let result = self.db.is_member(channel_id, pubkey).await?; + self.membership_cache.insert(key, result); + Ok(result) + } + + /// Invalidate caches after a membership change (add/remove member). + pub fn invalidate_membership(&self, channel_id: Uuid, pubkey: &[u8]) { + self.membership_cache + .invalidate(&(channel_id, pubkey.to_vec())); + self.accessible_channels_cache.invalidate(&pubkey.to_vec()); + } + + /// Invalidate all users' accessible-channels cache (e.g. new open channel created). + pub fn invalidate_all_accessible_channels(&self) { + self.accessible_channels_cache.invalidate_all(); + } + + /// Invalidate all caches after a channel is deleted. + /// + /// Channel deletion is a rare admin operation. We clear the entire membership + /// cache because moka doesn't support prefix-based invalidation on composite + /// keys, and stale `is_member=true` entries for a deleted channel would bypass + /// the DB's `deleted_at IS NULL` guard. + pub fn invalidate_channel_deleted(&self) { + self.membership_cache.invalidate_all(); + self.accessible_channels_cache.invalidate_all(); + } + + /// Get accessible channel IDs with a 10-second cache. Falls back to DB on miss. 
+ pub async fn get_accessible_channel_ids_cached( + &self, + pubkey: &[u8], + ) -> Result, sprout_db::DbError> { + let key = pubkey.to_vec(); + if let Some(cached) = self.accessible_channels_cache.get(&key) { + metrics::counter!("sprout_accessible_channels_cache_hits_total").increment(1); + return Ok(cached); + } + metrics::counter!("sprout_accessible_channels_cache_misses_total").increment(1); + let result = self.db.get_accessible_channel_ids(pubkey).await?; + self.accessible_channels_cache.insert(key, result.clone()); + Ok(result) + } +} + +/// Handle for graceful audit worker shutdown. +/// +/// Signals the worker to stop accepting new entries, drain its buffer, +/// and exit. Independent of `Arc<AppState>` lifetime — works even when +/// background tasks (reaper, pubsub, health) still hold state clones. +pub struct AuditShutdownHandle { + cancel: CancellationToken, + handle: JoinHandle<()>, +} + +impl AuditShutdownHandle { + /// Signal the audit worker to drain and wait up to `timeout` for it to finish. + pub async fn drain(self, timeout: std::time::Duration) { + self.cancel.cancel(); + match tokio::time::timeout(timeout, self.handle).await { + Ok(Ok(())) => tracing::info!("Audit worker drained cleanly"), + Ok(Err(e)) => tracing::error!("Audit worker panicked: {e}"), + Err(_) => tracing::error!( + ?timeout, + "Audit worker did not drain in time — exiting anyway" + ), + } + } +} + +/// Log a single audit entry with metrics. Extracted so the normal loop +/// and the post-cancel drain share the same logic. 
+async fn log_audit_entry(audit: &sprout_audit::AuditService, entry: sprout_audit::NewAuditEntry) { + let t = std::time::Instant::now(); + if let Err(e) = audit.log(entry).await { + metrics::counter!("sprout_audit_log_errors_total").increment(1); + tracing::error!("Audit log failed: {e}"); + } else { + metrics::histogram!("sprout_audit_log_seconds").record(t.elapsed().as_secs_f64()); + } } impl std::fmt::Debug for AppState { diff --git a/crates/sprout-relay/src/subscription.rs b/crates/sprout-relay/src/subscription.rs index e204aa5d5..084ed1a31 100644 --- a/crates/sprout-relay/src/subscription.rs +++ b/crates/sprout-relay/src/subscription.rs @@ -33,6 +33,10 @@ pub struct SubscriptionRegistry { channel_kind_index: DashMap>, /// Subscriptions with a channel_id but no kind filter — need to receive ALL kinds. channel_wildcard_index: DashMap>, + /// Global subscriptions indexed by kind — avoids O(all_subs) scan for global events. + global_kind_index: DashMap>, + /// Global subscriptions with no kind filter — wildcard, receives all global events. + global_wildcard_index: DashMap<(), Vec<(ConnId, SubId)>>, } impl SubscriptionRegistry { @@ -86,6 +90,25 @@ impl SubscriptionRegistry { } } } + } else { + // Global subscription — index by kind for sub-linear fan-out. + match extract_kinds_from_filters(&filters) { + None => { + self.global_wildcard_index + .entry(()) + .or_default() + .push((conn_id, sub_id.clone())); + } + Some(kinds) if kinds.is_empty() => {} + Some(kinds) => { + for kind in kinds { + self.global_kind_index + .entry(kind) + .or_default() + .push((conn_id, sub_id.clone())); + } + } + } } } @@ -165,18 +188,29 @@ impl SubscriptionRegistry { } } } else { - // Global event (channel_id = None) — only deliver to global subscriptions. - // Channel-scoped subscriptions are skipped: they target a specific channel - // and should not receive global infrastructure events (e.g. membership - // notifications) even if tag matching would succeed. 
-            for conn_entry in self.subs.iter() {
-                let conn_id = *conn_entry.key();
-                for (sub_id, (filters, sub_channel_id)) in conn_entry.value().iter() {
-                    if sub_channel_id.is_some() {
-                        continue; // skip channel-scoped subscriptions
+            // Global event (channel_id = None) — look candidates up in the global indexes
+            // instead of scanning every subscription. Channel-scoped subscriptions are never
+            // put in these indexes, so the scoping invariant holds without an explicit skip check.
+            if let Some(candidates) = self.global_kind_index.get(&event.event.kind) {
+                for (conn_id, sub_id) in candidates.iter() { // NOTE(review): the index guard is held while `subs` is accessed — confirm the removal path never takes the two maps in the opposite order
+                    if let Some(conn_subs) = self.subs.get(conn_id) { // index entries with no live record in `subs` are skipped
+                        if let Some((filters, _)) = conn_subs.get(sub_id.as_str()) {
+                            if filters_match(filters, event) { // the kind lookup only narrows candidates; the full filter check still decides
+                                results.push((*conn_id, sub_id.clone()));
+                            }
+                        }
                     }
-                    if filters_match(filters, event) {
-                        results.push((conn_id, sub_id.clone()));
+                }
+            }
+            // Kindless global subscriptions are candidates for every global event, whatever its kind.
+            if let Some(wildcards) = self.global_wildcard_index.get(&()) {
+                for (conn_id, sub_id) in wildcards.iter() {
+                    if let Some(conn_subs) = self.subs.get(conn_id) {
+                        if let Some((filters, _)) = conn_subs.get(sub_id.as_str()) {
+                            if filters_match(filters, event) {
+                                results.push((*conn_id, sub_id.clone()));
+                            }
+                        }
                     }
                 }
             }
@@ -255,8 +289,32 @@ impl SubscriptionRegistry {
                 }
             }
         }
+        } else {
+            // Global subscription — remove it from whichever global index it was registered in.
+            match extract_kinds_from_filters(filters) {
+                None => {
+                    if let Some(mut entries) = self.global_wildcard_index.get_mut(&()) {
+                        entries.retain(|(cid, sid)| !(*cid == conn_id && sid == sub_id));
+                    }
+                    // Remove the bucket only if it is still empty at removal time: `remove_if`
+                    // evaluates the predicate atomically with the removal, so an entry pushed
+                    // by a concurrent `register` after the guard above is released is kept.
+                    self.global_wildcard_index.remove_if(&(), |_, subs| subs.is_empty());
+                }
+                Some(kinds) if kinds.is_empty() => {} // empty kind list was never indexed, so there is nothing to remove
+                Some(kinds) => {
+                    for kind in kinds {
+                        if let Some(mut entries) = self.global_kind_index.get_mut(&kind) {
+                            entries.retain(|(cid, sid)| !(*cid == conn_id && sid == sub_id));
+                        }
+                        // As for the wildcard bucket: drop the per-kind bucket only if it is
+                        // still empty when the removal happens, so a racing `register` for
+                        // this kind does not lose its freshly pushed entry.
+                        self.global_kind_index.remove_if(&kind, |_, subs| subs.is_empty());
+                    }
+                }
+            }
         }
-        // If no channel_id, there's nothing in the index to remove (slow-path subs aren't indexed)
     }
 }
 
@@ -767,4 +825,131 @@ mod tests {
         assert_eq!(matches_b.len(), 1);
         assert_eq!(matches_b[0].1, "sub-b");
     }
+
+    #[test]
+    fn test_global_kind_index_fan_out() {
+        // Global subscriptions with explicit kinds should use the global_kind_index
+        // for sub-linear fan-out instead of scanning all subs.
+        let registry = SubscriptionRegistry::new();
+        let conn_a = Uuid::new_v4();
+        let conn_b = Uuid::new_v4();
+
+        registry.register(
+            conn_a,
+            "global_text".to_string(),
+            vec![Filter::new().kind(Kind::TextNote)],
+            None,
+        );
+        registry.register(
+            conn_b,
+            "global_meta".to_string(),
+            vec![Filter::new().kind(Kind::Metadata)],
+            None,
+        );
+
+        let event_text = make_stored_event(Kind::TextNote, None);
+        let matches = registry.fan_out(&event_text);
+        assert_eq!(matches.len(), 1);
+        assert_eq!(matches[0].0, conn_a);
+
+        let event_meta = make_stored_event(Kind::Metadata, None);
+        let matches = registry.fan_out(&event_meta);
+        assert_eq!(matches.len(), 1);
+        assert_eq!(matches[0].0, conn_b);
+
+        // Unrelated kind matches nobody.
+        let event_custom = make_stored_event(Kind::Custom(9999), None);
+        assert!(registry.fan_out(&event_custom).is_empty());
+    }
+
+    #[test]
+    fn test_global_wildcard_index_fan_out() {
+        // A global subscription with no kind filter should receive every global event, whatever its kind.
+        let registry = SubscriptionRegistry::new();
+        let conn_id = Uuid::new_v4();
+
+        registry.register(
+            conn_id,
+            "global_wildcard".to_string(),
+            vec![Filter::new()], // kindless filter, so the subscription lands in the wildcard index
+            None,
+        );
+
+        let event_text = make_stored_event(Kind::TextNote, None);
+        let matches = registry.fan_out(&event_text);
+        assert_eq!(matches.len(), 1);
+
+        let event_meta = make_stored_event(Kind::Metadata, None);
+        let matches = registry.fan_out(&event_meta);
+        assert_eq!(matches.len(), 1);
+
+        // Must NOT receive channel-scoped events, even though the kind would match.
+        let channel_event = make_stored_event(Kind::TextNote, Some(Uuid::new_v4()));
+        assert!(registry.fan_out(&channel_event).is_empty());
+    }
+
+    #[test]
+    fn test_global_index_removal_cleanup() {
+        // Removing the last global subscription for a key should also remove that key from the index.
+        let registry = SubscriptionRegistry::new();
+        let conn_id = Uuid::new_v4();
+
+        // Kind-specific global sub.
+        registry.register(
+            conn_id,
+            "g1".to_string(),
+            vec![Filter::new().kind(Kind::TextNote)],
+            None,
+        );
+        assert!(registry.global_kind_index.get(&Kind::TextNote).is_some());
+
+        registry.remove_subscription(conn_id, "g1");
+        assert!(registry.global_kind_index.get(&Kind::TextNote).is_none()); // the emptied bucket itself is gone, not just its contents
+
+        // Wildcard global sub.
+        registry.register(conn_id, "g2".to_string(), vec![Filter::new()], None);
+        assert!(registry.global_wildcard_index.get(&()).is_some());
+
+        registry.remove_subscription(conn_id, "g2");
+        assert!(registry.global_wildcard_index.get(&()).is_none());
+    }
+
+    #[test]
+    fn test_global_and_channel_subs_are_isolated() {
+        // Global subs must not see channel events; channel subs must not see global events.
+        // This checks that the invariant still holds with the new global indexes in place.
+        let registry = SubscriptionRegistry::new();
+        let conn_global = Uuid::new_v4();
+        let conn_channel = Uuid::new_v4();
+        let channel_id = Uuid::new_v4();
+
+        registry.register(
+            conn_global,
+            "global".to_string(),
+            vec![Filter::new().kind(Kind::TextNote)],
+            None, // no channel id: registered as a global subscription
+        );
+        registry.register(
+            conn_channel,
+            "channel".to_string(),
+            vec![Filter::new().kind(Kind::TextNote)], // same kind as the global sub, so only the scope differs
+            Some(channel_id),
+        );
+
+        let global_event = make_stored_event(Kind::TextNote, None);
+        let matches = registry.fan_out(&global_event);
+        assert_eq!(matches.len(), 1);
+        assert_eq!(
+            matches[0].0, conn_global,
+            "only global sub sees global event"
+        );
+
+        let channel_event = make_stored_event(Kind::TextNote, Some(channel_id));
+        let matches = registry.fan_out(&channel_event);
+        assert_eq!(matches.len(), 1);
+        assert_eq!(
+            matches[0].0, conn_channel,
+            "only channel sub sees channel event"
+        );
+    }
 }
diff --git a/crates/sprout-relay/src/workflow_sink.rs b/crates/sprout-relay/src/workflow_sink.rs
index 325dd282f..b95a7b0ef 100644
--- a/crates/sprout-relay/src/workflow_sink.rs
+++ b/crates/sprout-relay/src/workflow_sink.rs
@@ -90,8 +90,7 @@ impl ActionSink for RelayActionSink {
         let author_pubkey_bytes = author_pubkey.serialize().to_vec();
         let author_pubkey_hex = author_pubkey.to_hex();
         let is_member = state
-            .db
-            .is_member(channel_uuid, &author_pubkey_bytes)
+            .is_member_cached(channel_uuid, &author_pubkey_bytes) // replaces the direct `state.db.is_member` call with the cached lookup on `state`
             .await
             .map_err(|e| ActionSinkError::Database(e.to_string()))?;
         if !is_member && channel.visibility != "open" {