Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 80 additions & 10 deletions payjoin-directory/src/db/files.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ struct V2WaitMapEntry {

#[derive(Debug)]
struct V1WaitMapEntry {
payload: Arc<Vec<u8>>,
/// The V1 payload. `take()`n after the first read for data minimization —
/// plaintext PSBTs should not linger in memory longer than needed.
payload: Option<Arc<Vec<u8>>>,
sender: oneshot::Sender<Vec<u8>>,
}

Expand Down Expand Up @@ -325,9 +327,12 @@ impl DbTrait for Db {
impl Mailboxes {
async fn read(&mut self, id: &ShortId) -> io::Result<Option<Arc<Vec<u8>>>> {
// V1 POST requests are only stored in memory since they are
// unencrypted. Check this hash table first.
if let Some(V1WaitMapEntry { payload, .. }) = self.pending_v1.get(id) {
return Ok(Some(payload.clone()));
// unencrypted. Check this hash table first. Use take() for data
// minimization — clear the plaintext payload after first read.
if let Some(entry) = self.pending_v1.get_mut(id) {
if let Some(payload) = entry.payload.take() {
return Ok(Some(payload));
}
}

// V2 requests are stored on disk
Expand Down Expand Up @@ -358,8 +363,11 @@ impl Mailboxes {
return Err(Error::OverCapacity);
}

if self.pending_v1.contains_key(id) {
return Err(Error::OverCapacity);
if let Some(entry) = self.pending_v1.get(id) {
Comment thread
spacebear21 marked this conversation as resolved.
if entry.payload.is_some() {
return Err(Error::OverCapacity);
}
return Err(Error::AlreadyRead);
}

let receiver = self
Expand Down Expand Up @@ -419,13 +427,17 @@ impl Mailboxes {
let payload = payload.clone();
let (sender, receiver) = oneshot::channel::<Vec<u8>>();
ret = Some(receiver);
V1WaitMapEntry { payload, sender }
V1WaitMapEntry { payload: Some(payload), sender }
});

// If there are pending readers, satisfy them and mark the payload as read
// If there are pending readers, satisfy them with the payload
// and clear the in-memory copy for data minimization
if let Some(pending) = self.pending_v2.remove(id) {
trace!("notifying pending readers for {} (v1 fallback)", id);
pending.sender.send(payload).expect("sending on oneshot channel must succeed");
pending.sender.send(payload.clone()).expect("sending on oneshot channel must succeed");
if let Some(entry) = self.pending_v1.get_mut(id) {
entry.payload.take();
}
}

Ok(ret)
Expand Down Expand Up @@ -568,6 +580,9 @@ pub enum Error {
/// Operation rejected due to lack of capacity
OverCapacity,

/// Indicates receiver already consumed the plaintext V1 request payload
AlreadyRead,

/// Indicates the sender that was waiting for the reply is no longer there
V1SenderUnavailable,

Expand All @@ -584,6 +599,7 @@ impl From<Error> for super::Error<std::io::Error> {
match val {
Error::V1SenderUnavailable => super::Error::V1SenderUnavailable,
Error::OverCapacity => super::Error::OverCapacity,
Error::AlreadyRead => super::Error::AlreadyRead,
Error::IO(e) => super::Error::Operational(e),
}
}
Expand All @@ -603,6 +619,7 @@ impl std::fmt::Display for Error {
use Error::*;
match self {
OverCapacity => "Database over capacity".fmt(f),
AlreadyRead => "Mailbox payload already read".fmt(f),
V1SenderUnavailable => "Sender no longer connected".fmt(f),
IO(e) => write!(f, "Internal Error: {e}"),
}
Expand Down Expand Up @@ -780,7 +797,7 @@ async fn test_v2_wait() -> std::io::Result<()> {

match db.wait_for_v2_payload(&id).await {
Err(super::Error::Timeout(_)) => {}
res => panic!("expected timeout, got {:?}", res),
res => panic!("expected timeout, got {res:?}"),
}

let read_task1 = tokio::spawn({
Expand Down Expand Up @@ -870,6 +887,59 @@ async fn test_v1_wait() -> std::io::Result<()> {
Ok(())
}

#[tokio::test]
async fn test_v1_data_minimization() -> std::io::Result<()> {
let dir = tempfile::tempdir()?;

let db = Arc::new(
Db::init(Duration::from_millis(500), dir.path().to_owned())
.await
.expect("initializing mailbox database should succeed"),
);

let id = ShortId([0u8; 8]);

// Spawn v1 sender in background
let v1_sender_task = tokio::spawn({
let db = db.clone();
async move { db.post_v1_request_and_wait_for_response(&id, b"request".to_vec()).await }
});

// Small delay to let v1 request post
tokio::time::sleep(Duration::from_millis(10)).await;

// First read should return the payload
let res = db.wait_for_v2_payload(&id).await.expect("first read should succeed");
assert_eq!(&res[..], b"request", "first read should return the payload");

// Subsequent reads should not return the plaintext payload again.
assert!(
matches!(db.wait_for_v2_payload(&id).await, Err(super::Error::AlreadyRead)),
"subsequent reads should indicate the payload was already consumed"
);

// Verify the payload was cleared from memory by checking directly
{
let guard = db.mailboxes.lock().await;
let entry = guard.pending_v1.get(&id);
assert!(
entry.is_none_or(|e| e.payload.is_none()),
"v1 payload should have been cleared after first read"
);
}

// V1 response flow should still work even after payload was cleared
db.post_v1_response(&id, b"response".to_vec()).await.expect("posting response should succeed");

let res = v1_sender_task
.await
.expect("joining task should succeed")
.expect("v1 sender should get response");
assert_eq!(&res[..], b"response", "v1 sender should receive the response");

Ok(())
}

// Simulate elapsed time deterministically by shifting stored timestamps
// backward instead of sleeping. tokio::time::pause() can't be used because
// prune compares against SystemTime (timestamps originate from disk).
Expand Down
2 changes: 2 additions & 0 deletions payjoin-directory/src/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ pub enum Error<OperationalError: SendableError> {
Operational(OperationalError),
Timeout(tokio::time::error::Elapsed),
OverCapacity,
AlreadyRead,
V1SenderUnavailable,
}

Expand All @@ -33,6 +34,7 @@ impl<E: SendableError> std::fmt::Display for Error<E> {
Operational(error) => write!(f, "Db error: {error}"),
Timeout(timeout) => write!(f, "Timeout: {timeout}"),
OverCapacity => "Database over capacity".fmt(f),
AlreadyRead => "Mailbox payload already read".fmt(f),
V1SenderUnavailable => "Sender no longer connected".fmt(f),
}
}
Expand Down
1 change: 1 addition & 0 deletions payjoin-directory/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,7 @@ fn handle_peek<Error: db::SendableError>(
db::Error::OverCapacity => Err(HandlerError::ServiceUnavailable(anyhow::Error::msg(
"mailbox storage at capacity",
))),
db::Error::AlreadyRead => Ok(timeout_response),
db::Error::V1SenderUnavailable => Err(HandlerError::SenderGone(anyhow::Error::msg(
"Sender is unavailable try a new request",
))),
Expand Down
Loading