diff --git a/rust/lance-io/src/local.rs b/rust/lance-io/src/local.rs index 8b2cc4e0453..a08f81460e2 100644 --- a/rust/lance-io/src/local.rs +++ b/rust/lance-io/src/local.rs @@ -17,6 +17,7 @@ use std::os::windows::fs::FileExt; use async_trait::async_trait; use bytes::{Bytes, BytesMut}; use deepsize::DeepSizeOf; +use futures::future::BoxFuture; use lance_core::{Error, Result}; use object_store::path::Path; use snafu::location; @@ -153,7 +154,6 @@ impl LocalObjectReader { } } -#[async_trait] impl Reader for LocalObjectReader { fn path(&self) -> &Path { &self.path @@ -168,80 +168,86 @@ impl Reader for LocalObjectReader { } /// Returns the file size. - async fn size(&self) -> object_store::Result { - let file = self.file.clone(); - self.size - .get_or_try_init(|| async move { - let metadata = tokio::task::spawn_blocking(move || { - file.metadata().map_err(|err| object_store::Error::Generic { - store: "LocalFileSystem", - source: err.into(), + fn size(&self) -> BoxFuture<'_, object_store::Result> { + Box::pin(async move { + let file = self.file.clone(); + self.size + .get_or_try_init(|| async move { + let metadata = tokio::task::spawn_blocking(move || { + file.metadata().map_err(|err| object_store::Error::Generic { + store: "LocalFileSystem", + source: err.into(), + }) }) + .await??; + Ok(metadata.len() as usize) }) - .await??; - Ok(metadata.len() as usize) - }) - .await - .cloned() + .await + .cloned() + }) } /// Reads a range of data. #[instrument(level = "debug", skip(self))] - async fn get_range(&self, range: Range) -> object_store::Result { + fn get_range(&self, range: Range) -> BoxFuture<'static, object_store::Result> { let file = self.file.clone(); let io_tracker = self.io_tracker.clone(); let path = self.path.clone(); let num_bytes = range.len() as u64; let range_u64 = (range.start as u64)..(range.end as u64); - let result = tokio::task::spawn_blocking(move || { - let mut buf = BytesMut::with_capacity(range.len()); - // Safety: `buf` is set with appropriate capacity above. It is - // written to below and we check all data is initialized at that point. - unsafe { buf.set_len(range.len()) }; - #[cfg(unix)] - file.read_exact_at(buf.as_mut(), range.start as u64)?; - #[cfg(windows)] - read_exact_at(file, buf.as_mut(), range.start as u64)?; - - Ok(buf.freeze()) - }) - .await? - .map_err(|err: std::io::Error| object_store::Error::Generic { - store: "LocalFileSystem", - source: err.into(), - }); - - if result.is_ok() { - io_tracker.record_read("get_range", path, num_bytes, Some(range_u64)); - } + Box::pin(async move { + let result = tokio::task::spawn_blocking(move || { + let mut buf = BytesMut::with_capacity(range.len()); + // Safety: `buf` is set with appropriate capacity above. It is + // written to below and we check all data is initialized at that point. + unsafe { buf.set_len(range.len()) }; + #[cfg(unix)] + file.read_exact_at(buf.as_mut(), range.start as u64)?; + #[cfg(windows)] + read_exact_at(file, buf.as_mut(), range.start as u64)?; + + Ok(buf.freeze()) + }) + .await? + .map_err(|err: std::io::Error| object_store::Error::Generic { + store: "LocalFileSystem", + source: err.into(), + }); + + if result.is_ok() { + io_tracker.record_read("get_range", path, num_bytes, Some(range_u64)); + } - result + result + }) } /// Reads the entire file. #[instrument(level = "debug", skip(self))] - async fn get_all(&self) -> object_store::Result { - let mut file = self.file.clone(); - let io_tracker = self.io_tracker.clone(); - let path = self.path.clone(); + fn get_all(&self) -> BoxFuture<'_, object_store::Result> { + Box::pin(async move { + let mut file = self.file.clone(); + let io_tracker = self.io_tracker.clone(); + let path = self.path.clone(); + + let result = tokio::task::spawn_blocking(move || { + let mut buf = Vec::new(); + file.read_to_end(buf.as_mut())?; + Ok(Bytes::from(buf)) + }) + .await? + .map_err(|err: std::io::Error| object_store::Error::Generic { + store: "LocalFileSystem", + source: err.into(), + }); + + if let Ok(bytes) = &result { + io_tracker.record_read("get_all", path, bytes.len() as u64, None); + } - let result = tokio::task::spawn_blocking(move || { - let mut buf = Vec::new(); - file.read_to_end(buf.as_mut())?; - Ok(Bytes::from(buf)) + result }) - .await? - .map_err(|err: std::io::Error| object_store::Error::Generic { - store: "LocalFileSystem", - source: err.into(), - }); - - if let Ok(bytes) = &result { - io_tracker.record_read("get_all", path, bytes.len() as u64, None); - } - - result } } diff --git a/rust/lance-io/src/object_reader.rs b/rust/lance-io/src/object_reader.rs index 3f79daca540..b81a3d75752 100644 --- a/rust/lance-io/src/object_reader.rs +++ b/rust/lance-io/src/object_reader.rs @@ -4,7 +4,6 @@ use std::ops::Range; use std::sync::Arc; -use async_trait::async_trait; use bytes::Bytes; use deepsize::DeepSizeOf; use futures::{ @@ -18,6 +17,35 @@ use tracing::instrument; use crate::{object_store::DEFAULT_CLOUD_IO_PARALLELISM, traits::Reader}; +trait StaticGetRange { + fn path(&self) -> &Path; + fn get_range(&self) -> BoxFuture<'static, OSResult>; +} + +/// A wrapper around an object store and a path that implements a static +/// get_range method by assuming self is stored in an Arc. +struct GetRequest { + object_store: Arc, + path: Path, + options: GetOptions, +} + +impl StaticGetRange for Arc { + fn path(&self) -> &Path { + &self.path + } + + fn get_range(&self) -> BoxFuture<'static, OSResult> { + let store_and_path = self.clone(); + Box::pin(async move { + store_and_path + .object_store + .get_opts(&store_and_path.path, store_and_path.options.clone()) + .await + }) + } +} + /// Object Reader /// /// Object Store + Base Path @@ -58,64 +86,62 @@ impl CloudObjectReader { download_retry_count, }) } +} - // Retries for the initial request are handled by object store, but - // there are no retries for failures that occur during the streaming - // of the response body. Thus we add an outer retry loop here. - async fn do_with_retry<'a, O>( - &self, - f: impl Fn() -> BoxFuture<'a, OSResult>, - ) -> OSResult { - let mut retries = 3; - loop { - match f().await { - Ok(val) => return Ok(val), - Err(err) => { - if retries == 0 { - return Err(err); - } - retries -= 1; +// Retries for the initial request are handled by object store, but +// there are no retries for failures that occur during the streaming +// of the response body. Thus we add an outer retry loop here. +async fn do_with_retry<'a, O>(f: impl Fn() -> BoxFuture<'a, OSResult> + Clone) -> OSResult { + let mut retries = 3; + loop { + let f = f.clone(); + match f().await { + Ok(val) => return Ok(val), + Err(err) => { + if retries == 0 { + return Err(err); } + retries -= 1; } } } +} - // We have a separate retry loop here. This is because object_store does not - // attempt retries on downloads that fail during streaming of the response body. - // - // However, this failure is pretty common (e.g. timeout) and we want to retry in these - // situations. In addition, we provide additional logging information in these - // failures cases. - async fn do_get_with_outer_retry<'a>( - &self, - f: impl Fn() -> BoxFuture<'a, OSResult> + Copy, - desc: impl Fn() -> String, - ) -> OSResult { - let mut retries = self.download_retry_count; - loop { - let get_result = self.do_with_retry(f).await?; - match get_result.bytes().await { - Ok(bytes) => return Ok(bytes), - Err(err) => { - if retries == 0 { - log::warn!("Failed to download {} from {} after {} attempts. This may indicate that cloud storage is overloaded or your timeout settings are too restrictive. Error details: {:?}", desc(), self.path, self.download_retry_count, err); - return Err(err); - } - log::debug!( - "Retrying {} from {} (remaining retries: {}). Error details: {:?}", - desc(), - self.path, - retries, - err - ); - retries -= 1; +// We have a separate retry loop here. This is because object_store does not +// attempt retries on downloads that fail during streaming of the response body. +// +// However, this failure is pretty common (e.g. timeout) and we want to retry in these +// situations. In addition, we provide additional logging information in these +// failures cases. +async fn do_get_with_outer_retry( + download_retry_count: usize, + get_request: Arc, + desc: impl Fn() -> String, +) -> OSResult { + let mut retries = download_retry_count; + loop { + let get_request_clone = get_request.clone(); + let get_result = do_with_retry(move || get_request_clone.get_range()).await?; + match get_result.bytes().await { + Ok(bytes) => return Ok(bytes), + Err(err) => { + if retries == 0 { + log::warn!("Failed to download {} from {} after {} attempts. This may indicate that cloud storage is overloaded or your timeout settings are too restrictive. Error details: {:?}", desc(), get_request.path(), download_retry_count, err); + return Err(err); } + log::debug!( + "Retrying {} from {} (remaining retries: {}). Error details: {:?}", + desc(), + get_request.path(), + retries, + err + ); + retries -= 1; } } } } -#[async_trait] impl Reader for CloudObjectReader { fn path(&self) -> &Path { &self.path @@ -130,52 +156,64 @@ impl Reader for CloudObjectReader { } /// Object/File Size. - async fn size(&self) -> object_store::Result { - self.size - .get_or_try_init(|| async move { - let meta = self - .do_with_retry(|| self.object_store.head(&self.path)) - .await?; - Ok(meta.size as usize) - }) - .await - .cloned() + fn size(&self) -> BoxFuture<'_, object_store::Result> { + Box::pin(async move { + self.size + .get_or_try_init(|| async move { + let meta = do_with_retry(|| self.object_store.head(&self.path)).await?; + Ok(meta.size as usize) + }) + .await + .cloned() + }) } #[instrument(level = "debug", skip(self))] - async fn get_range(&self, range: Range) -> OSResult { - self.do_get_with_outer_retry( - || { - let options = GetOptions { - range: Some( - Range { - start: range.start as u64, - end: range.end as u64, - } - .into(), - ), - ..Default::default() - }; - self.object_store.get_opts(&self.path, options) + fn get_range(&self, range: Range) -> BoxFuture<'static, OSResult> { + let get_request = Arc::new(GetRequest { + object_store: self.object_store.clone(), + path: self.path.clone(), + options: GetOptions { + range: Some( + Range { + start: range.start as u64, + end: range.end as u64, + } + .into(), + ), + ..Default::default() }, - || format!("range {:?}", range), - ) - .await + }); + Box::pin(do_get_with_outer_retry( + self.download_retry_count, + get_request, + move || format!("range {:?}", range), + )) } #[instrument(level = "debug", skip_all)] - async fn get_all(&self) -> OSResult { - self.do_get_with_outer_retry( - || { - self.object_store - .get_opts(&self.path, GetOptions::default()) - }, - || "read_all".to_string(), - ) - .await + fn get_all(&self) -> BoxFuture<'_, OSResult> { + let get_request = Arc::new(GetRequest { + object_store: self.object_store.clone(), + path: self.path.clone(), + options: GetOptions::default(), + }); + Box::pin(async move { + do_get_with_outer_retry(self.download_retry_count, get_request, || { + "read_all".to_string() + }) + .await + }) } } +#[derive(Debug)] +pub struct SmallReaderInner { + path: Path, + size: usize, + state: std::sync::Mutex, +} + /// A reader for a file so small, we just eagerly read it all into memory. /// /// When created, it represents a future that will read the whole file into memory. @@ -183,11 +221,9 @@ impl Reader for CloudObjectReader { /// On the first read call, it will start the read. Multiple threads can call read at the same time. /// /// Once the read is complete, any thread can call read again to get the result. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct SmallReader { - path: Path, - size: usize, - state: Arc>, + inner: Arc, } enum SmallReaderState { @@ -231,12 +267,16 @@ impl SmallReader { .shared(), ); Self { - path, - size, - state: Arc::new(std::sync::Mutex::new(state)), + inner: Arc::new(SmallReaderInner { + path, + size, + state: std::sync::Mutex::new(state), + }), } } +} +impl SmallReaderInner { async fn wait(&self) -> OSResult { let future = { let state = self.state.lock().unwrap(); @@ -258,10 +298,9 @@ impl SmallReader { } } -#[async_trait] impl Reader for SmallReader { fn path(&self) -> &Path { - &self.path + &self.inner.path } fn block_size(&self) -> usize { @@ -273,12 +312,15 @@ impl Reader for SmallReader { } /// Object/File Size. - async fn size(&self) -> OSResult { - Ok(self.size) + fn size(&self) -> BoxFuture<'_, OSResult> { + let size = self.inner.size; + Box::pin(async move { Ok(size) }) } - async fn get_range(&self, range: Range) -> OSResult { - self.wait().await.and_then(|bytes| { + fn get_range(&self, range: Range) -> BoxFuture<'static, OSResult> { + let inner = self.inner.clone(); + Box::pin(async move { + let bytes = inner.wait().await?; let start = range.start; let end = range.end; if start >= bytes.len() || end > bytes.len() { @@ -297,16 +339,16 @@ impl Reader for SmallReader { }) } - async fn get_all(&self) -> OSResult { - self.wait().await + fn get_all(&self) -> BoxFuture<'_, OSResult> { + Box::pin(async move { self.inner.wait().await }) } } impl DeepSizeOf for SmallReader { fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize { - let mut size = self.path.as_ref().deep_size_of_children(context); + let mut size = self.inner.path.as_ref().deep_size_of_children(context); - if let Ok(guard) = self.state.try_lock() { + if let Ok(guard) = self.inner.state.try_lock() { if let SmallReaderState::Finished(Ok(data)) = &*guard { size += data.len(); } diff --git a/rust/lance-io/src/traits.rs b/rust/lance-io/src/traits.rs index 046e4e4a558..4a6631e6ba8 100644 --- a/rust/lance-io/src/traits.rs +++ b/rust/lance-io/src/traits.rs @@ -6,6 +6,7 @@ use std::ops::Range; use async_trait::async_trait; use bytes::Bytes; use deepsize::DeepSizeOf; +use futures::future::BoxFuture; use object_store::path::Path; use prost::Message; use tokio::io::{AsyncWrite, AsyncWriteExt}; @@ -79,7 +80,6 @@ impl WriteExt for W { } } -#[async_trait] pub trait Reader: std::fmt::Debug + Send + Sync + DeepSizeOf { fn path(&self) -> &Path; @@ -90,16 +90,16 @@ pub trait Reader: std::fmt::Debug + Send + Sync + DeepSizeOf { fn io_parallelism(&self) -> usize; /// Object/File Size. - async fn size(&self) -> object_store::Result; + fn size(&self) -> BoxFuture<'_, object_store::Result>; /// Read a range of bytes from the object. /// /// TODO: change to read_at()? - async fn get_range(&self, range: Range) -> object_store::Result; + fn get_range(&self, range: Range) -> BoxFuture<'static, object_store::Result>; /// Read all bytes from the object. /// /// By default this reads the size in a separate IOP but some implementations /// may not need the size beforehand. - async fn get_all(&self) -> object_store::Result; + fn get_all(&self) -> BoxFuture<'_, object_store::Result>; }