diff --git a/src/serve/grpc/mod.rs b/src/serve/grpc/mod.rs index 2e5341322..62ce6486e 100644 --- a/src/serve/grpc/mod.rs +++ b/src/serve/grpc/mod.rs @@ -38,13 +38,13 @@ pub async fn serve( ) -> Result<(), Error> { let addr = config.listen_address.parse().unwrap(); - let sync_service = sync::SyncServiceImpl::new(wal.clone(), ledger.clone(), chain); + let sync_service = sync::SyncServiceImpl::new(wal.clone(), ledger.clone(), chain, exit.clone()); let sync_service = u5c::sync::sync_service_server::SyncServiceServer::new(sync_service); let query_service = query::QueryServiceImpl::new(ledger.clone(), genesis.clone()); let query_service = u5c::query::query_service_server::QueryServiceServer::new(query_service); - let watch_service = watch::WatchServiceImpl::new(wal.clone(), ledger.clone()); + let watch_service = watch::WatchServiceImpl::new(wal.clone(), ledger.clone(), exit.clone()); let watch_service = u5c::watch::watch_service_server::WatchServiceServer::new(watch_service); let submit_service = submit::SubmitServiceImpl::new(mempool, ledger.clone()); diff --git a/src/serve/grpc/sync.rs b/src/serve/grpc/sync.rs index 666a5c869..dd4aec402 100644 --- a/src/serve/grpc/sync.rs +++ b/src/serve/grpc/sync.rs @@ -7,6 +7,7 @@ use pallas::interop::utxorpc::spec::sync::BlockRef; use pallas::interop::utxorpc::{spec as u5c, Mapper}; use pallas::ledger::traverse::MultiEraBlock; use std::pin::Pin; +use tokio_util::sync::CancellationToken; use tonic::{Request, Response, Status}; use crate::chain::ChainStore; @@ -87,14 +88,21 @@ pub struct SyncServiceImpl { wal: wal::redb::WalStore, chain: ChainStore, mapper: interop::Mapper, + cancellation_token: CancellationToken, } impl SyncServiceImpl { - pub fn new(wal: wal::redb::WalStore, ledger: LedgerStore, chain: ChainStore) -> Self { + pub fn new( + wal: wal::redb::WalStore, + ledger: LedgerStore, + chain: ChainStore, + cancellation_token: CancellationToken, + ) -> Self { Self { wal, mapper: Mapper::new(ledger), chain, + cancellation_token, } } } @@ -194,9 +202,10 @@ impl u5c::sync::sync_service_server::SyncService for SyncServiceImpl { let reset = once(async { Ok(point_to_reset_tip_response(point)) }); - let forward = wal::WalStream::start(self.wal.clone(), from_seq) - .skip(1) - .map(move |(_, log)| Ok(wal_log_to_tip_response(&mapper, &log))); + let forward = + wal::WalStream::start(self.wal.clone(), from_seq, self.cancellation_token.clone()) + .skip(1) + .map(move |(_, log)| Ok(wal_log_to_tip_response(&mapper, &log))); let stream = reset.chain(forward); diff --git a/src/serve/grpc/watch.rs b/src/serve/grpc/watch.rs index 94082e16d..c2b6df113 100644 --- a/src/serve/grpc/watch.rs +++ b/src/serve/grpc/watch.rs @@ -11,6 +11,7 @@ use pallas::{ ledger::{addresses::Address, traverse::MultiEraBlock}, }; use std::pin::Pin; +use tokio_util::sync::CancellationToken; use tonic::{Request, Response, Status}; fn outputs_match_address( @@ -196,13 +197,19 @@ fn roll_to_watch_response( pub struct WatchServiceImpl { wal: wal::redb::WalStore, mapper: interop::Mapper, + cancellation_token: CancellationToken, } impl WatchServiceImpl { - pub fn new(wal: wal::redb::WalStore, ledger: LedgerStore) -> Self { + pub fn new( + wal: wal::redb::WalStore, + ledger: LedgerStore, + cancellation_token: CancellationToken, + ) -> Self { Self { wal, mapper: interop::Mapper::new(ledger), + cancellation_token, } } } @@ -241,9 +248,10 @@ impl u5c::watch::watch_service_server::WatchService for WatchServiceImpl { let mapper = self.mapper.clone(); - let stream = wal::WalStream::start(self.wal.clone(), from_seq) - .flat_map(move |(_, log)| roll_to_watch_response(&mapper, &log, &inner_req)) - .map(Ok); + let stream = + wal::WalStream::start(self.wal.clone(), from_seq, self.cancellation_token.clone()) + .flat_map(move |(_, log)| roll_to_watch_response(&mapper, &log, &inner_req)) + .map(Ok); Ok(Response::new(Box::pin(stream))) } diff --git a/src/wal/stream.rs b/src/wal/stream.rs index bb430c944..28bf2caae 100644 --- a/src/wal/stream.rs +++ b/src/wal/stream.rs @@ -1,11 +1,16 @@ use futures_core::Stream; +use tokio_util::sync::CancellationToken; use super::*; pub struct WalStream; impl WalStream { - pub fn start(wal: R, from: super::LogSeq) -> impl Stream + pub fn start( + wal: R, + from: super::LogSeq, + cancellation_token: CancellationToken, + ) -> impl Stream where R: WalReader, { @@ -20,12 +25,18 @@ impl WalStream { } loop { - wal.tip_change().await.unwrap(); - let iter = wal.crawl_from(Some(last_seq)).unwrap().skip(1); - - for entry in iter { - last_seq = entry.0; - yield entry; + tokio::select! { + _ = cancellation_token.cancelled() => { + break; + } + _ = wal.tip_change() => { + let iter = wal.crawl_from(Some(last_seq)).unwrap().skip(1); + + for entry in iter { + last_seq = entry.0; + yield entry; + } + } } } }