Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions crates/sprout-db/src/channel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -522,6 +522,32 @@ pub async fn get_members(pool: &PgPool, channel_id: Uuid) -> Result<Vec<MemberRe
rows.into_iter().map(row_to_member_record).collect()
}

/// Returns active members for multiple channels in a single query.
///
/// Designed for small-batch use (e.g. DM participant resolution where each
/// channel has 2-9 members). For large channel sets, consider pagination.
/// Returns a flat `Vec<MemberRecord>` ordered by `joined_at`; callers should
/// group by `channel_id` if per-channel access is needed.
/// Returns an empty vec immediately when `channel_ids` is empty.
pub async fn get_members_bulk(pool: &PgPool, channel_ids: &[Uuid]) -> Result<Vec<MemberRecord>> {
if channel_ids.is_empty() {
return Ok(Vec::new());
}
let rows = sqlx::query(
r#"
SELECT cm.channel_id, cm.pubkey, cm.role::text AS role, cm.joined_at, cm.invited_by, cm.removed_at
FROM channel_members cm
JOIN channels c ON cm.channel_id = c.id AND c.deleted_at IS NULL
WHERE cm.channel_id = ANY($1) AND cm.removed_at IS NULL
ORDER BY cm.joined_at ASC
"#,
)
.bind(channel_ids)
.fetch_all(pool)
.await?;
rows.into_iter().map(row_to_member_record).collect()
}

/// Get all channel IDs accessible to a pubkey.
///
/// Includes channels where the pubkey is an active member AND all open channels.
Expand Down
15 changes: 13 additions & 2 deletions crates/sprout-db/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,11 +141,14 @@ pub struct DbConfig {
}

impl Default for DbConfig {
/// Sized for a single relay pod against PG max_connections=100.
/// Staging measured 51 idle + 1 active out of 50 — most connections sat unused.
/// At 20 main + 5 audit = 25/pod, four relay pods fit within the PG limit.
fn default() -> Self {
Self {
database_url: "postgres://sprout:sprout_dev@localhost:5432/sprout".to_string(),
max_connections: 50,
min_connections: 5,
max_connections: 20,
min_connections: 2,
acquire_timeout_secs: 3,
max_lifetime_secs: 1800,
idle_timeout_secs: 600,
Expand Down Expand Up @@ -398,6 +401,14 @@ impl Db {
channel::get_members(&self.pool, channel_id).await
}

/// Returns active members for multiple channels in a single query.
pub async fn get_members_bulk(
&self,
channel_ids: &[Uuid],
) -> Result<Vec<channel::MemberRecord>> {
channel::get_members_bulk(&self.pool, channel_ids).await
}

/// Get all channel IDs accessible to a pubkey.
pub async fn get_accessible_channel_ids(&self, pubkey: &[u8]) -> Result<Vec<Uuid>> {
channel::get_accessible_channel_ids(&self.pool, pubkey).await
Expand Down
110 changes: 66 additions & 44 deletions crates/sprout-relay/src/api/channels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,77 @@ pub async fn channels_handler(
.await
.unwrap_or_default();

// ── Batch DM participant resolution (2 queries total, not 2×N_DMs) ──
let dm_channel_ids: Vec<uuid::Uuid> = channels
.iter()
.filter(|ac| ac.channel.channel_type == "dm")
.map(|ac| ac.channel.id)
.collect();

// 1. One query: all members for all DM channels.
let all_dm_members = state
.db
.get_members_bulk(&dm_channel_ids)
.await
.unwrap_or_else(|e| {
tracing::error!("channels: failed to bulk-load DM members: {e}");
vec![]
});

// 2. Collect unique pubkeys across all DM members.
let unique_pubkeys: Vec<Vec<u8>> = {
let mut seen = std::collections::HashSet::new();
all_dm_members
.iter()
.filter(|m| seen.insert(m.pubkey.clone()))
.map(|m| m.pubkey.clone())
.collect()
};

// 3. One query: resolve display names for all unique pubkeys.
let user_records = state
.db
.get_users_bulk(&unique_pubkeys)
.await
.unwrap_or_else(|e| {
tracing::error!("channels: failed to bulk-load DM participant profiles: {e}");
vec![]
});
let user_map: HashMap<String, String> = user_records
.into_iter()
.filter_map(|u| {
let hex = nostr_hex::encode(&u.pubkey);
u.display_name.map(|name| (hex, name))
})
.collect();

// 4. Group members by channel_id for O(1) lookup.
let mut members_by_channel: HashMap<uuid::Uuid, Vec<&sprout_db::channel::MemberRecord>> =
HashMap::new();
for m in &all_dm_members {
members_by_channel.entry(m.channel_id).or_default().push(m);
}

let mut result = Vec::with_capacity(channels.len());

for ac in &channels {
let ch = &ac.channel;
let (participants, participant_pubkeys) = if ch.channel_type == "dm" {
resolve_dm_participants(&state, ch.id).await
let members = members_by_channel.get(&ch.id);
let mut names = Vec::new();
let mut pk_hexes = Vec::new();
if let Some(members) = members {
for m in members {
let hex = nostr_hex::encode(&m.pubkey);
let name = user_map
.get(&hex)
.cloned()
.unwrap_or_else(|| hex[..8.min(hex.len())].to_string());
names.push(name);
pk_hexes.push(hex);
}
}
(names, pk_hexes)
} else {
(vec![], vec![])
};
Expand Down Expand Up @@ -119,46 +184,3 @@ fn channel_record_to_json(
"ttl_deadline": channel.ttl_deadline.map(|t| t.to_rfc3339()),
})
}

/// Fetch DM participants and resolve their display names.
async fn resolve_dm_participants(
state: &AppState,
channel_id: uuid::Uuid,
) -> (Vec<String>, Vec<String>) {
let members = state.db.get_members(channel_id).await.unwrap_or_else(|e| {
tracing::error!("channels: failed to load members for channel {channel_id}: {e}");
vec![]
});

let member_pubkeys: Vec<Vec<u8>> = members.iter().map(|m| m.pubkey.clone()).collect();

let user_records = state
.db
.get_users_bulk(&member_pubkeys)
.await
.unwrap_or_else(|e| {
tracing::error!("channels: failed to load user records for DM participants: {e}");
vec![]
});

let user_map: HashMap<String, String> = user_records
.into_iter()
.filter_map(|u| {
let hex = nostr_hex::encode(&u.pubkey);
u.display_name.map(|name| (hex, name))
})
.collect();

let mut names = Vec::new();
let mut pk_hexes = Vec::new();
for m in &members {
let hex = nostr_hex::encode(&m.pubkey);
let name = user_map
.get(&hex)
.cloned()
.unwrap_or_else(|| hex[..8.min(hex.len())].to_string());
names.push(name);
pk_hexes.push(hex);
}
(names, pk_hexes)
}
20 changes: 16 additions & 4 deletions crates/sprout-relay/src/api/dms.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,14 @@ pub async fn open_dm_handler(
.map_err(|e| internal_error(&format!("db error: {e}")))?;

if was_created {
// Invalidate membership + accessible-channels caches for all participants
// so REQ, /api/feed, and /api/search immediately include the new DM.
// Note: DM hide/unhide does NOT need cache invalidation because
// get_accessible_channel_ids() does not filter on hidden_at.
for pk in &all_bytes {
state.invalidate_membership(channel.id, pk);
}

let actor_hex = nostr_hex::encode(&self_bytes);
let participant_hexes: Vec<String> = all_bytes.iter().map(nostr_hex::encode).collect();
if let Err(e) = emit_system_message(
Expand Down Expand Up @@ -199,8 +207,7 @@ pub async fn add_dm_member_handler(

// Verify caller is a member of the existing DM.
let is_member = state
.db
.is_member(channel_id, &self_bytes)
.is_member_cached(channel_id, &self_bytes)
.await
.map_err(|e| internal_error(&format!("db error: {e}")))?;
if !is_member {
Expand Down Expand Up @@ -262,6 +269,12 @@ pub async fn add_dm_member_handler(
.map_err(|e| internal_error(&format!("db error: {e}")))?;

if was_created {
// Invalidate membership + accessible-channels caches for all participants
// so REQ, /api/feed, and /api/search immediately include the new DM.
for pk in &all_bytes {
state.invalidate_membership(new_channel.id, pk);
}

// Emit NIP-29 group discovery events for the new expanded DM.
if let Err(e) = emit_group_discovery_events(&state, new_channel.id).await {
tracing::warn!(channel = %new_channel.id, "DM discovery emission failed: {e}");
Expand Down Expand Up @@ -393,8 +406,7 @@ pub async fn hide_dm_handler(

// Verify caller is a member.
let is_member = state
.db
.is_member(channel_id, &ctx.pubkey_bytes)
.is_member_cached(channel_id, &ctx.pubkey_bytes)
.await
.map_err(|e| internal_error(&format!("db error: {e}")))?;

Expand Down
3 changes: 1 addition & 2 deletions crates/sprout-relay/src/api/feed.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,7 @@ pub async fn feed_handler(

let accessible_ids = constrain_channel_ids(
state
.db
.get_accessible_channel_ids(&pubkey_bytes)
.get_accessible_channel_ids_cached(&pubkey_bytes)
.await
.map_err(|e| internal_error(&format!("db error: {e}")))?,
ctx.channel_ids.as_deref(),
Expand Down
38 changes: 20 additions & 18 deletions crates/sprout-relay/src/api/media.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,26 +159,28 @@ pub async fn upload_blob(
};
metrics::counter!("sprout_media_uploads_total", "mime" => mime_label.to_owned()).increment(1);

// Fire-and-forget audit — never block the response on audit I/O.
let audit = state.audit.clone();
// Audit via bounded channel — same pattern as event audit.
let desc = descriptor.clone();
let uploader = auth.auth_event.pubkey.to_hex();
tokio::spawn(async move {
let _ = audit
.log(NewAuditEntry {
event_id: desc.sha256.clone(),
event_kind: sprout_core::kind::KIND_MEDIA_UPLOAD,
actor_pubkey: uploader,
action: AuditAction::MediaUploaded,
channel_id: None,
metadata: serde_json::json!({
"sha256": desc.sha256,
"size": desc.size,
"mime": desc.mime_type,
}),
})
.await;
});
if let Err(e) = state
.audit_tx
.send(NewAuditEntry {
event_id: desc.sha256.clone(),
event_kind: sprout_core::kind::KIND_MEDIA_UPLOAD,
actor_pubkey: uploader,
action: AuditAction::MediaUploaded,
channel_id: None,
metadata: serde_json::json!({
"sha256": desc.sha256,
"size": desc.size,
"mime": desc.mime_type,
}),
})
.await
{
tracing::error!("Media audit channel closed — entry lost: {e}");
metrics::counter!("sprout_audit_send_errors_total").increment(1);
}

Ok(Json(descriptor))
}
Expand Down
3 changes: 1 addition & 2 deletions crates/sprout-relay/src/api/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -520,8 +520,7 @@ pub(crate) async fn check_channel_membership(
pubkey_bytes: &[u8],
) -> Result<(), (StatusCode, Json<serde_json::Value>)> {
let is_member = state
.db
.is_member(channel_id, pubkey_bytes)
.is_member_cached(channel_id, pubkey_bytes)
.await
.map_err(|e| internal_error(&format!("db error: {e}")))?;
if is_member {
Expand Down
3 changes: 1 addition & 2 deletions crates/sprout-relay/src/api/search.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ pub async fn search_handler(

let channel_ids = constrain_channel_ids(
state
.db
.get_accessible_channel_ids(&pubkey_bytes)
.get_accessible_channel_ids_cached(&pubkey_bytes)
.await
.unwrap_or_default(),
ctx.channel_ids.as_deref(),
Expand Down
3 changes: 1 addition & 2 deletions crates/sprout-relay/src/api/tokens.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,8 +473,7 @@ pub async fn post_tokens(

// Verify caller is a member of the channel.
let is_member = state
.db
.is_member(cid, &ctx.pubkey_bytes)
.is_member_cached(cid, &ctx.pubkey_bytes)
.await
.map_err(|e| internal_error(&format!("db error: {e}")))?;
if !is_member {
Expand Down
7 changes: 3 additions & 4 deletions crates/sprout-relay/src/audio/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -512,8 +512,7 @@ async fn ensure_membership(

// Fast path: already a member.
let is_member = state
.db
.is_member(channel_id, pubkey_bytes)
.is_member_cached(channel_id, pubkey_bytes)
.await
.map_err(|e| format!("db error: {e}"))?;

Expand All @@ -537,8 +536,7 @@ async fn ensure_membership(
if channel.ttl_seconds.is_some() {
if let Some(parent_id) = parent_channel_id {
let parent_member = state
.db
.is_member(parent_id, pubkey_bytes)
.is_member_cached(parent_id, pubkey_bytes)
.await
.map_err(|e| format!("db error: {e}"))?;

Expand All @@ -553,6 +551,7 @@ async fn ensure_membership(
)
.await
.map_err(|e| format!("auto-add failed: {e}"))?;
state.invalidate_membership(channel_id, pubkey_bytes);

return Ok(());
}
Expand Down
Loading