diff --git a/src/events.rs b/src/events.rs index ca03e6c..e50a035 100644 --- a/src/events.rs +++ b/src/events.rs @@ -13,6 +13,7 @@ //! Because we don't have a built-in runtime, it's up to the end-user to poll //! [`crate::LiquidityManager::get_and_clear_pending_events()`] to receive events. +use crate::transport::msgs::{LSPS0Response, RequestId}; use std::collections::VecDeque; use std::sync::{Condvar, Mutex}; @@ -55,4 +56,16 @@ impl EventQueue { /// Event which you should probably take some action in response to. #[derive(Debug, Clone, PartialEq, Eq)] -pub enum Event {} +pub struct Event { + /// The id from the request + pub id: RequestId, + /// The result of request + pub result: EventResult, +} + +/// Content of the event +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum EventResult { + /// The LSPS0 response + LSPS0(LSPS0Response), +} diff --git a/src/transport/message_handler.rs b/src/transport/message_handler.rs index 2f28124..c68e116 100644 --- a/src/transport/message_handler.rs +++ b/src/transport/message_handler.rs @@ -1,5 +1,5 @@ -use crate::events::{Event, EventQueue}; -use crate::transport::msgs::{LSPSMessage, RawLSPSMessage, LSPS_MESSAGE_TYPE}; +use crate::events::{Event, EventQueue, EventResult}; +use crate::transport::msgs::{LSPSMessage, RawLSPSMessage, RequestId, LSPS_MESSAGE_TYPE}; use crate::transport::protocol::LSPS0MessageHandler; use bitcoin::secp256k1::PublicKey; @@ -32,7 +32,7 @@ pub(crate) trait ProtocolMessageHandler { fn handle_message( &self, message: Self::ProtocolMessage, counterparty_node_id: &PublicKey, - ) -> Result<(), LightningError>; + ) -> Result, LightningError>; } /// A configuration for [`LiquidityManager`]. @@ -121,6 +121,25 @@ where { self.pending_events.get_and_clear_pending_events() } + /// Returns the waiting event result and close + pub fn wait_event_result(self, request_id: RequestId) -> EventResult { + loop { + let events = self.pending_events.get_and_clear_pending_events(); + for Event { id, result } in events { + if id == request_id { + return result; + } + } + } + } + + /// This allows the list the protocols of LSPS0 + pub fn list_protocols( + &self, counterparty_node_id: PublicKey, + ) -> Result { + Ok(self.lsps0_message_handler.list_protocols(counterparty_node_id)) + } + fn handle_lsps_message( &self, msg: LSPSMessage, sender_node_id: &PublicKey, ) -> Result<(), lightning::ln::msgs::LightningError> { @@ -129,7 +148,11 @@ where { return Err(LightningError { err: format!("{} did not understand a message we previously sent, maybe they don't support a protocol we are trying to use?", sender_node_id), action: ErrorAction::IgnoreAndLog(Level::Error)}); } LSPSMessage::LSPS0(msg) => { - self.lsps0_message_handler.handle_message(msg, sender_node_id)?; + if let Some(event) = + self.lsps0_message_handler.handle_message(msg, sender_node_id)? + { + self.pending_events.enqueue(event); + } } } Ok(()) diff --git a/src/transport/protocol.rs b/src/transport/protocol.rs index bc9c118..113d390 100644 --- a/src/transport/protocol.rs +++ b/src/transport/protocol.rs @@ -5,6 +5,7 @@ use lightning::util::logger::Level; use std::ops::Deref; use std::sync::{Arc, Mutex}; +use crate::events::{Event, EventResult}; use crate::transport::message_handler::ProtocolMessageHandler; use crate::transport::msgs::{ LSPS0Message, LSPS0Request, LSPS0Response, LSPSMessage, ListProtocolsRequest, @@ -32,13 +33,14 @@ where Self { entropy_source, protocols, pending_messages } } - pub fn list_protocols(&self, counterparty_node_id: PublicKey) { + pub fn list_protocols(&self, counterparty_node_id: PublicKey) -> RequestId { + let request_id = utils::generate_request_id(&self.entropy_source); let msg = LSPS0Message::Request( - utils::generate_request_id(&self.entropy_source), + request_id.clone(), LSPS0Request::ListProtocols(ListProtocolsRequest {}), ); - self.enqueue_message(counterparty_node_id, msg); + request_id } fn enqueue_message(&self, counterparty_node_id: PublicKey, message: LSPS0Message) { @@ -46,27 +48,25 @@ where } fn handle_request( - &self, request_id: RequestId, request: LSPS0Request, counterparty_node_id: &PublicKey, - ) -> Result<(), lightning::ln::msgs::LightningError> { + &self, id: RequestId, request: LSPS0Request, counterparty_node_id: &PublicKey, + ) -> Result, lightning::ln::msgs::LightningError> { match request { LSPS0Request::ListProtocols(_) => { - let msg = LSPS0Message::Response( - request_id, - LSPS0Response::ListProtocols(ListProtocolsResponse { - protocols: self.protocols.clone(), - }), - ); + let response = LSPS0Response::ListProtocols(ListProtocolsResponse { + protocols: self.protocols.clone(), + }); + let msg = LSPS0Message::Response(id.clone(), response.clone()); self.enqueue_message(*counterparty_node_id, msg); - Ok(()) + Ok(Some(Event { id, result: EventResult::LSPS0(response) })) } } } fn handle_response( &self, response: LSPS0Response, counterparty_node_id: &PublicKey, - ) -> Result<(), LightningError> { + ) -> Result, LightningError> { match response { - LSPS0Response::ListProtocols(ListProtocolsResponse { protocols }) => Ok(()), + LSPS0Response::ListProtocols(ListProtocolsResponse { protocols }) => Ok(None), LSPS0Response::ListProtocolsError(ResponseError { code, message, data, .. }) => { Err(LightningError { err: format!( @@ -89,7 +89,7 @@ where fn handle_message( &self, message: Self::ProtocolMessage, counterparty_node_id: &PublicKey, - ) -> Result<(), LightningError> { + ) -> Result, LightningError> { match message { LSPS0Message::Request(request_id, request) => { self.handle_request(request_id, request, counterparty_node_id)