diff --git a/rust/lance-io/src/utils.rs b/rust/lance-io/src/utils.rs index a1792d122f8..c63947803a1 100644 --- a/rust/lance-io/src/utils.rs +++ b/rust/lance-io/src/utils.rs @@ -23,6 +23,8 @@ use crate::{ use crate::{traits::Reader, ReadBatchParams}; use lance_core::{Error, Result}; +pub mod tracking_store; + /// Read a binary array from a [Reader]. /// pub async fn read_binary_array( diff --git a/rust/lance-io/src/utils/tracking_store.rs b/rust/lance-io/src/utils/tracking_store.rs new file mode 100644 index 00000000000..a1b4f3b0a77 --- /dev/null +++ b/rust/lance-io/src/utils/tracking_store.rs @@ -0,0 +1,459 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Make assertions about IO operations to an [ObjectStore]. +//! +//! When testing code that performs IO, you will often want to make assertions +//! about the number of reads and writes performed, the amount of data read or +//! written, and the number of disjoint periods where at least one IO is in-flight. +//! +//! This module provides [`IOTracker`] which can be used to wrap any object store.
+use std::fmt::{Display, Formatter}; +use std::ops::Range; +use std::sync::{atomic::AtomicU16, Arc, Mutex}; + +use bytes::Bytes; +use futures::stream::BoxStream; +use object_store::path::Path; +use object_store::{ + GetOptions, GetRange, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore, + PutMultipartOptions, PutOptions, PutPayload, PutResult, Result as OSResult, UploadPart, +}; + +use crate::object_store::WrappingObjectStore; + +#[derive(Debug, Default, Clone)] +pub struct IOTracker(Arc>); + +impl IOTracker { + pub fn incremental_stats(&self) -> IoStats { + std::mem::take(&mut *self.0.lock().unwrap()) + } +} + +impl WrappingObjectStore for IOTracker { + fn wrap( + &self, + target: Arc, + _storage_options: Option<&std::collections::HashMap>, + ) -> Arc { + Arc::new(IoTrackingStore::new(target, self.0.clone())) + } +} + +#[derive(Debug, Default)] +pub struct IoStats { + pub read_iops: u64, + pub read_bytes: u64, + pub write_iops: u64, + pub write_bytes: u64, + /// Number of disjoint periods where at least one IO is in-flight. + pub num_hops: u64, + pub requests: Vec, +} + +/// Assertions on IO statistics. +/// assert_io_eq!(io_stats, read_iops, 1); +/// assert_io_eq!(io_stats, write_iops, 0, "should be no writes"); +/// assert_io_eq!(io_stats, num_hops, 1, "should be just {}", "one hop"); +#[macro_export] +macro_rules! assert_io_eq { + ($io_stats:expr, $field:ident, $expected:expr) => { + assert_eq!( + $io_stats.$field, $expected, + "Expected {} to be {}, got {}. Requests: {:#?}", + stringify!($field), + $expected, + $io_stats.$field, + $io_stats.requests + ); + }; + ($io_stats:expr, $field:ident, $expected:expr, $($arg:tt)+) => { + assert_eq!( + $io_stats.$field, $expected, + "Expected {} to be {}, got {}. Requests: {:#?} {}", + stringify!($field), + $expected, + $io_stats.$field, + $io_stats.requests, + format_args!($($arg)+) + ); + }; +} + +#[macro_export] +macro_rules! 
assert_io_gt { + ($io_stats:expr, $field:ident, $expected:expr) => { + assert!( + $io_stats.$field > $expected, + "Expected {} to be > {}, got {}. Requests: {:#?}", + stringify!($field), + $expected, + $io_stats.$field, + $io_stats.requests + ); + }; + ($io_stats:expr, $field:ident, $expected:expr, $($arg:tt)+) => { + assert!( + $io_stats.$field > $expected, + "Expected {} to be > {}, got {}. Requests: {:#?} {}", + stringify!($field), + $expected, + $io_stats.$field, + $io_stats.requests, + format_args!($($arg)+) + ); + }; +} + +#[macro_export] +macro_rules! assert_io_lt { + ($io_stats:expr, $field:ident, $expected:expr) => { + assert!( + $io_stats.$field < $expected, + "Expected {} to be < {}, got {}. Requests: {:#?}", + stringify!($field), + $expected, + $io_stats.$field, + $io_stats.requests + ); + }; + ($io_stats:expr, $field:ident, $expected:expr, $($arg:tt)+) => { + assert!( + $io_stats.$field < $expected, + "Expected {} to be < {}, got {}. Requests: {:#?} {}", + stringify!($field), + $expected, + $io_stats.$field, + $io_stats.requests, + format_args!($($arg)+) + ); + }; +} + +// These fields are "dead code" because we just use them right now to display +// in test failure messages through Debug. (The lint ignores Debug impls.) 
+#[allow(dead_code)] +#[derive(Clone)] +pub struct IoRequestRecord { + pub method: &'static str, + pub path: Path, + pub range: Option>, +} + +impl std::fmt::Debug for IoRequestRecord { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + // For example: "put /path/to/file range: 0-100" + write!( + f, + "IORequest(method={}, path=\"{}\"", + self.method, self.path + )?; + if let Some(range) = &self.range { + write!(f, ", range={:?}", range)?; + } + write!(f, ")")?; + Ok(()) + } +} + +impl Display for IoStats { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{:#?}", self) + } +} + +#[derive(Debug)] +pub struct IoTrackingStore { + target: Arc, + stats: Arc>, + active_requests: Arc, +} + +impl Display for IoTrackingStore { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{:#?}", self) + } +} + +impl IoTrackingStore { + fn new(target: Arc, stats: Arc>) -> Self { + Self { + target, + stats, + active_requests: Arc::new(AtomicU16::new(0)), + } + } + + fn record_read( + &self, + method: &'static str, + path: Path, + num_bytes: u64, + range: Option>, + ) { + let mut stats = self.stats.lock().unwrap(); + stats.read_iops += 1; + stats.read_bytes += num_bytes; + stats.requests.push(IoRequestRecord { + method, + path, + range, + }); + } + + fn record_write(&self, method: &'static str, path: Path, num_bytes: u64) { + let mut stats = self.stats.lock().unwrap(); + stats.write_iops += 1; + stats.write_bytes += num_bytes; + stats.requests.push(IoRequestRecord { + method, + path, + range: None, + }); + } + + fn hop_guard(&self) -> HopGuard { + HopGuard::new(self.active_requests.clone(), self.stats.clone()) + } +} + +#[async_trait::async_trait] +#[deny(clippy::missing_trait_methods)] +impl ObjectStore for IoTrackingStore { + async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult { + let _guard = self.hop_guard(); + self.record_write("put", location.to_owned(), bytes.content_length() as u64); + 
self.target.put(location, bytes).await + } + + async fn put_opts( + &self, + location: &Path, + bytes: PutPayload, + opts: PutOptions, + ) -> OSResult { + let _guard = self.hop_guard(); + self.record_write( + "put_opts", + location.to_owned(), + bytes.content_length() as u64, + ); + self.target.put_opts(location, bytes, opts).await + } + + async fn put_multipart(&self, location: &Path) -> OSResult> { + let _guard = self.hop_guard(); + let target = self.target.put_multipart(location).await?; + Ok(Box::new(IoTrackingMultipartUpload { + target, + stats: self.stats.clone(), + path: location.to_owned(), + _guard, + })) + } + + async fn put_multipart_opts( + &self, + location: &Path, + opts: PutMultipartOptions, + ) -> OSResult> { + let _guard = self.hop_guard(); + let target = self.target.put_multipart_opts(location, opts).await?; + Ok(Box::new(IoTrackingMultipartUpload { + target, + stats: self.stats.clone(), + path: location.to_owned(), + _guard, + })) + } + + async fn get(&self, location: &Path) -> OSResult { + let _guard = self.hop_guard(); + let result = self.target.get(location).await; + if let Ok(result) = &result { + let num_bytes = result.range.end - result.range.start; + self.record_read("get", location.to_owned(), num_bytes, None); + } + result + } + + async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult { + let _guard = self.hop_guard(); + let range = match &options.range { + Some(GetRange::Bounded(range)) => Some(range.clone()), + _ => None, // TODO: fill in other options. 
+ }; + let result = self.target.get_opts(location, options).await; + if let Ok(result) = &result { + let num_bytes = result.range.end - result.range.start; + + self.record_read("get_opts", location.to_owned(), num_bytes, range); + } + result + } + + async fn get_range(&self, location: &Path, range: Range) -> OSResult { + let _guard = self.hop_guard(); + let result = self.target.get_range(location, range.clone()).await; + if let Ok(result) = &result { + self.record_read( + "get_range", + location.to_owned(), + result.len() as u64, + Some(range), + ); + } + result + } + + async fn get_ranges(&self, location: &Path, ranges: &[Range]) -> OSResult> { + let _guard = self.hop_guard(); + let result = self.target.get_ranges(location, ranges).await; + if let Ok(result) = &result { + self.record_read( + "get_ranges", + location.to_owned(), + result.iter().map(|b| b.len() as u64).sum(), + None, + ); + } + result + } + + async fn head(&self, location: &Path) -> OSResult { + let _guard = self.hop_guard(); + self.record_read("head", location.to_owned(), 0, None); + self.target.head(location).await + } + + async fn delete(&self, location: &Path) -> OSResult<()> { + let _guard = self.hop_guard(); + self.record_write("delete", location.to_owned(), 0); + self.target.delete(location).await + } + + fn delete_stream<'a>( + &'a self, + locations: BoxStream<'a, OSResult>, + ) -> BoxStream<'a, OSResult> { + self.target.delete_stream(locations) + } + + fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult> { + let _guard = self.hop_guard(); + self.record_read("list", prefix.cloned().unwrap_or_default(), 0, None); + self.target.list(prefix) + } + + fn list_with_offset( + &self, + prefix: Option<&Path>, + offset: &Path, + ) -> BoxStream<'static, OSResult> { + self.record_read( + "list_with_offset", + prefix.cloned().unwrap_or_default(), + 0, + None, + ); + self.target.list_with_offset(prefix, offset) + } + + async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult 
{ + let _guard = self.hop_guard(); + self.record_read( + "list_with_delimiter", + prefix.cloned().unwrap_or_default(), + 0, + None, + ); + self.target.list_with_delimiter(prefix).await + } + + async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> { + let _guard = self.hop_guard(); + self.record_write("copy", from.to_owned(), 0); + self.target.copy(from, to).await + } + + async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> { + let _guard = self.hop_guard(); + self.record_write("rename", from.to_owned(), 0); + self.target.rename(from, to).await + } + + async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> { + let _guard = self.hop_guard(); + self.record_write("rename_if_not_exists", from.to_owned(), 0); + self.target.rename_if_not_exists(from, to).await + } + + async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> { + let _guard = self.hop_guard(); + self.record_write("copy_if_not_exists", from.to_owned(), 0); + self.target.copy_if_not_exists(from, to).await + } +} + +#[derive(Debug)] +struct IoTrackingMultipartUpload { + target: Box, + path: Path, + stats: Arc>, + _guard: HopGuard, +} + +#[async_trait::async_trait] +impl MultipartUpload for IoTrackingMultipartUpload { + async fn abort(&mut self) -> OSResult<()> { + self.target.abort().await + } + + async fn complete(&mut self) -> OSResult { + self.target.complete().await + } + + fn put_part(&mut self, payload: PutPayload) -> UploadPart { + { + let mut stats = self.stats.lock().unwrap(); + stats.write_iops += 1; + stats.write_bytes += payload.content_length() as u64; + stats.requests.push(IoRequestRecord { + method: "put_part", + path: self.path.to_owned(), + range: None, + }); + } + self.target.put_part(payload) + } +} + +#[derive(Debug)] +struct HopGuard { + active_requests: Arc, + stats: Arc>, +} + +impl HopGuard { + fn new(active_requests: Arc, stats: Arc>) -> Self { + active_requests.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + Self { + 
active_requests, + stats, + } + } +} + +impl Drop for HopGuard { + fn drop(&mut self) { + if self + .active_requests + .fetch_sub(1, std::sync::atomic::Ordering::SeqCst) + == 1 + { + let mut stats = self.stats.lock().unwrap(); + stats.num_hops += 1; + } + } +} diff --git a/rust/lance/src/dataset.rs b/rust/lance/src/dataset.rs index e14cc6c4be1..fa00ce09969 100644 --- a/rust/lance/src/dataset.rs +++ b/rust/lance/src/dataset.rs @@ -2573,6 +2573,8 @@ mod tests { }; use lance_index::scalar::FullTextSearchQuery; use lance_index::{scalar::ScalarIndexParams, vector::DIST_COL, IndexType}; + use lance_io::assert_io_eq; + use lance_io::utils::tracking_store::IOTracker; use lance_io::utils::CachedFileSize; use lance_linalg::distance::MetricType; use lance_table::feature_flags; @@ -2827,7 +2829,7 @@ mod tests { #[tokio::test] async fn test_load_manifest_iops() { // Need to use in-memory for accurate IOPS tracking. - use crate::utils::test::IoTrackingStore; + let io_tracker = Arc::new(IOTracker::default()); // Use consistent session so memory store can be reused. 
let session = Arc::new(Session::default()); @@ -2842,13 +2844,12 @@ mod tests { ) .unwrap(); let batches = RecordBatchIterator::new(vec![Ok(batch)], schema.clone()); - let (io_stats_wrapper, io_stats) = IoTrackingStore::new_wrapper(); let _original_ds = Dataset::write( batches, "memory://test", Some(WriteParams { store_params: Some(ObjectStoreParams { - object_store_wrapper: Some(io_stats_wrapper.clone()), + object_store_wrapper: Some(io_tracker.clone()), ..Default::default() }), session: Some(session.clone()), @@ -2858,12 +2859,12 @@ mod tests { .await .unwrap(); - io_stats.lock().unwrap().read_iops = 0; + let _ = io_tracker.incremental_stats(); //reset let _dataset = DatasetBuilder::from_uri("memory://test") .with_read_params(ReadParams { store_options: Some(ObjectStoreParams { - object_store_wrapper: Some(io_stats_wrapper), + object_store_wrapper: Some(io_tracker.clone()), ..Default::default() }), session: Some(session), @@ -2873,13 +2874,12 @@ mod tests { .await .unwrap(); - let get_iops = || io_stats.lock().unwrap().read_iops; - // There should be only two IOPS: // 1. List _versions directory to get the latest manifest location // 2. Read the manifest file. (The manifest is small enough to be read in one go. // Larger manifests would result in more IOPS.) 
- assert_eq!(get_iops(), 2); + let io_stats = io_tracker.incremental_stats(); + assert_io_eq!(io_stats, read_iops, 2); } #[rstest] diff --git a/rust/lance/src/dataset/fragment.rs b/rust/lance/src/dataset/fragment.rs index 6ad17fa9e81..8043f2abf15 100644 --- a/rust/lance/src/dataset/fragment.rs +++ b/rust/lance/src/dataset/fragment.rs @@ -2484,7 +2484,11 @@ mod tests { use lance_core::ROW_ID; use lance_datagen::{array, gen_batch, RowCount}; use lance_file::version::LanceFileVersion; - use lance_io::object_store::{ObjectStore, ObjectStoreParams}; + use lance_io::{ + assert_io_eq, assert_io_lt, + object_store::{ObjectStore, ObjectStoreParams}, + utils::tracking_store::IOTracker, + }; use pretty_assertions::assert_eq; use rstest::rstest; use v2::writer::FileWriterOptions; @@ -2496,7 +2500,7 @@ mod tests { InsertBuilder, }, session::Session, - utils::test::{StatsHolder, TestDatasetGenerator}, + utils::test::TestDatasetGenerator, }; async fn create_dataset(test_uri: &str, data_storage_version: LanceFileVersion) -> Dataset { @@ -3695,7 +3699,7 @@ mod tests { ) .unwrap(); let session = Arc::new(Session::default()); - let io_stats = Arc::new(StatsHolder::default()); + let io_stats = Arc::new(IOTracker::default()); let write_params = WriteParams { store_params: Some(ObjectStoreParams { object_store_wrapper: Some(io_stats.clone()), @@ -3714,8 +3718,8 @@ mod tests { // Assert file is small (< 4kb) { let stats = io_stats.incremental_stats(); - assert_eq!(stats.write_iops, 3); - assert!(stats.write_bytes < 4096); + assert_io_eq!(stats, write_iops, 3); + assert_io_lt!(stats, write_bytes, 4096); } // Measure IOPS needed to scan all data first time. 
@@ -3739,7 +3743,7 @@ mod tests { assert_eq!(data.num_columns(), 7); let stats = io_stats.incremental_stats(); - assert_eq!(stats.read_iops, 1); - assert!(stats.read_bytes < 4096); + assert_io_eq!(stats, read_iops, 1); + assert_io_lt!(stats, read_bytes, 4096); } } diff --git a/rust/lance/src/dataset/scanner.rs b/rust/lance/src/dataset/scanner.rs index 9a7b76c6d9e..f77abdd19ee 100644 --- a/rust/lance/src/dataset/scanner.rs +++ b/rust/lance/src/dataset/scanner.rs @@ -3770,7 +3770,6 @@ pub mod test_dataset { mod test { use std::collections::BTreeSet; - use std::sync::Mutex; use std::vec; use arrow::array::as_primitive_array; @@ -3797,7 +3796,9 @@ mod test { use lance_index::vector::pq::PQBuildParams; use lance_index::vector::sq::builder::SQBuildParams; use lance_index::{scalar::ScalarIndexParams, IndexType}; + use lance_io::assert_io_gt; use lance_io::object_store::ObjectStoreParams; + use lance_io::utils::tracking_store::IOTracker; use lance_linalg::distance::DistanceType; use lance_testing::datagen::{BatchGenerator, IncrementingInt32, RandomVector}; use rstest::rstest; @@ -3810,8 +3811,7 @@ mod test { use crate::dataset::WriteParams; use crate::index::vector::{StageParams, VectorIndexParams}; use crate::utils::test::{ - assert_plan_node_equals, DatagenExt, FragmentCount, FragmentRowCount, IoStats, - IoTrackingStore, + assert_plan_node_equals, DatagenExt, FragmentCount, FragmentRowCount, }; #[tokio::test] @@ -6142,6 +6142,7 @@ mod test { #[values(LanceFileVersion::Legacy, LanceFileVersion::Stable)] data_storage_version: LanceFileVersion, ) { + use lance_io::assert_io_lt; // Create a large dataset with a scalar indexed column and a sorted but not scalar // indexed column use lance_table::io::commit::RenameCommitHandler; @@ -6154,13 +6155,13 @@ mod test { .col("not_indexed", array::step::()) .into_reader_rows(RowCount::from(1000), BatchCount::from(20)); - let (io_stats_wrapper, io_stats) = IoTrackingStore::new_wrapper(); + let io_tracker = 
Arc::new(IOTracker::default()); let mut dataset = Dataset::write( data, "memory://test", Some(WriteParams { store_params: Some(ObjectStoreParams { - object_store_wrapper: Some(io_stats_wrapper), + object_store_wrapper: Some(io_tracker.clone()), ..Default::default() }), commit_handler: Some(Arc::new(RenameCommitHandler)), @@ -6181,15 +6182,13 @@ mod test { .await .unwrap(); - let get_bytes = || io_stats.lock().unwrap().read_bytes; - // First run a full scan to get a baseline - let start_bytes = get_bytes(); + let _ = io_tracker.incremental_stats(); // reset dataset.scan().try_into_batch().await.unwrap(); - let full_scan_bytes = get_bytes() - start_bytes; + let io_stats = io_tracker.incremental_stats(); + let full_scan_bytes = io_stats.read_bytes; // Next do a scan without pushdown, we should still see a benefit from late materialization - let start_bytes = get_bytes(); dataset .scan() .use_stats(false) @@ -6198,14 +6197,13 @@ mod test { .try_into_batch() .await .unwrap(); - let filtered_scan_bytes = get_bytes() - start_bytes; - - assert!(filtered_scan_bytes < full_scan_bytes); + let io_stats = io_tracker.incremental_stats(); + assert_io_lt!(io_stats, read_bytes, full_scan_bytes); + let filtered_scan_bytes = io_stats.read_bytes; // Now do a scan with pushdown, the benefit should be even greater // Pushdown only works with the legacy format for now. 
if data_storage_version == LanceFileVersion::Legacy { - let start_bytes = get_bytes(); dataset .scan() .filter("not_indexed = 50") @@ -6213,15 +6211,13 @@ mod test { .try_into_batch() .await .unwrap(); - let pushdown_scan_bytes = get_bytes() - start_bytes; - - assert!(pushdown_scan_bytes < filtered_scan_bytes); + let io_stats = io_tracker.incremental_stats(); + assert_io_lt!(io_stats, read_bytes, filtered_scan_bytes); } // Now do a scalar index scan, this should be better than a // full scan but since we have to load the index might be more // expensive than late / pushdown scan - let start_bytes = get_bytes(); dataset .scan() .filter("indexed = 50") @@ -6229,12 +6225,12 @@ mod test { .try_into_batch() .await .unwrap(); - let index_scan_bytes = get_bytes() - start_bytes; - assert!(index_scan_bytes < full_scan_bytes); + let io_stats = io_tracker.incremental_stats(); + assert_io_lt!(io_stats, read_bytes, full_scan_bytes); + let index_scan_bytes = io_stats.read_bytes; // A second scalar index scan should be cheaper than the first // since we should have the index in cache - let start_bytes = get_bytes(); dataset .scan() .filter("indexed = 50") @@ -6242,8 +6238,8 @@ mod test { .try_into_batch() .await .unwrap(); - let second_index_scan_bytes = get_bytes() - start_bytes; - assert!(second_index_scan_bytes < index_scan_bytes); + let io_stats = io_tracker.incremental_stats(); + assert_io_lt!(io_stats, read_bytes, index_scan_bytes); } #[rstest] @@ -7295,6 +7291,7 @@ mod test { // indexed column use lance_index::scalar::inverted::tokenizer::InvertedIndexParams; + use lance_io::assert_io_eq; let data = gen_batch() .col( "vector", @@ -7305,13 +7302,13 @@ mod test { .col("not_indexed", array::step::()) .into_reader_rows(RowCount::from(100), BatchCount::from(5)); - let (io_stats_wrapper, io_stats) = IoTrackingStore::new_wrapper(); + let io_tracker = Arc::new(IOTracker::default()); let mut dataset = Dataset::write( data, "memory://test", Some(WriteParams { store_params: 
Some(ObjectStoreParams { - object_store_wrapper: Some(io_stats_wrapper), + object_store_wrapper: Some(io_tracker.clone()), ..Default::default() }), data_storage_version: Some(data_storage_version), @@ -7367,31 +7364,6 @@ mod test { .await .unwrap(); - struct IopsTracker { - baseline: u64, - new_iops: u64, - io_stats: Arc>, - } - - impl IopsTracker { - fn update(&mut self) { - let iops = self.io_stats.lock().unwrap().read_iops; - self.new_iops = iops - self.baseline; - self.baseline = iops; - } - - fn new_iops(&mut self) -> u64 { - self.update(); - self.new_iops - } - } - - let mut tracker = IopsTracker { - baseline: 0, - new_iops: 0, - io_stats, - }; - // First planning cycle needs to do some I/O to determine what scalar indices are available dataset .scan() @@ -7403,7 +7375,8 @@ mod test { .unwrap(); // First pass will need to perform some IOPs to determine what scalar indices are available - assert!(tracker.new_iops() > 0); + let io_stats = io_tracker.incremental_stats(); + assert_io_gt!(io_stats, read_iops, 0); // Second planning cycle should not perform any I/O dataset @@ -7415,7 +7388,8 @@ mod test { .await .unwrap(); - assert_eq!(tracker.new_iops(), 0); + let io_stats = io_tracker.incremental_stats(); + assert_io_eq!(io_stats, read_iops, 0); dataset .scan() @@ -7426,7 +7400,8 @@ mod test { .await .unwrap(); - assert_eq!(tracker.new_iops(), 0); + let io_stats = io_tracker.incremental_stats(); + assert_io_eq!(io_stats, read_iops, 0); dataset .scan() @@ -7438,7 +7413,8 @@ mod test { .await .unwrap(); - assert_eq!(tracker.new_iops(), 0); + let io_stats = io_tracker.incremental_stats(); + assert_io_eq!(io_stats, read_iops, 0); dataset .scan() @@ -7450,7 +7426,8 @@ mod test { .await .unwrap(); - assert_eq!(tracker.new_iops(), 0); + let io_stats = io_tracker.incremental_stats(); + assert_io_eq!(io_stats, read_iops, 0); } #[rstest] diff --git a/rust/lance/src/dataset/write/commit.rs b/rust/lance/src/dataset/write/commit.rs index bb1ed481378..efa9be58a0e 100644 --- 
a/rust/lance/src/dataset/write/commit.rs +++ b/rust/lance/src/dataset/write/commit.rs @@ -487,6 +487,8 @@ pub struct BatchCommitResult { mod tests { use arrow::array::{Int32Array, RecordBatch}; use arrow_schema::{DataType, Field as ArrowField, Schema as ArrowSchema}; + use lance_io::utils::tracking_store::IOTracker; + use lance_io::{assert_io_eq, assert_io_gt}; use lance_io::{object_store::ChainedWrappingObjectStore, utils::CachedFileSize}; use lance_table::format::{DataFile, Fragment}; use std::time::Duration; @@ -495,10 +497,7 @@ mod tests { use crate::utils::test::ThrottledStoreWrapper; - use crate::{ - dataset::{InsertBuilder, WriteParams}, - utils::test::StatsHolder, - }; + use crate::dataset::{InsertBuilder, WriteParams}; use super::*; @@ -537,8 +536,7 @@ mod tests { #[tokio::test] async fn test_reuse_session() { // Need to use in-memory for accurate IOPS tracking. - use crate::utils::test::IoTrackingStore; - + let io_tracker = IOTracker::default(); let session = Arc::new(Session::default()); // Create new dataset let schema = Arc::new(ArrowSchema::new(vec![ArrowField::new( @@ -551,9 +549,8 @@ mod tests { vec![Arc::new(Int32Array::from_iter_values(0..10_i32))], ) .unwrap(); - let (io_stats_wrapper, io_stats) = IoTrackingStore::new_wrapper(); let store_params = ObjectStoreParams { - object_store_wrapper: Some(io_stats_wrapper), + object_store_wrapper: Some(Arc::new(io_tracker.clone())), ..Default::default() }; let dataset = InsertBuilder::new("memory://test") @@ -568,20 +565,9 @@ mod tests { .unwrap(); let dataset = Arc::new(dataset); - let reset_iops = || { - io_stats.lock().unwrap().read_iops = 0; - io_stats.lock().unwrap().write_iops = 0; - }; - let get_new_iops = || { - let read_iops = io_stats.lock().unwrap().read_iops; - let write_iops = io_stats.lock().unwrap().write_iops; - reset_iops(); - (read_iops, write_iops) - }; - - let (initial_reads, initial_writes) = get_new_iops(); - assert!(initial_reads > 0); - assert!(initial_writes > 0); + let io_stats = 
io_tracker.incremental_stats(); + assert_io_gt!(io_stats, read_iops, 0); + assert_io_gt!(io_stats, write_iops, 0); // Commit transaction 5 times for i in 0..5 { @@ -595,12 +581,12 @@ mod tests { // we shouldn't need to read anything from disk. Except we do need // to check for the latest version to see if we need to do conflict // resolution. - let (reads, writes) = get_new_iops(); - assert_eq!(reads, 1, "i = {}", i); + let io_stats = io_tracker.incremental_stats(); + assert_io_eq!(io_stats, read_iops, 1, "check latest version, i = {} ", i); // Should see 2 IOPs: // 1. Write the transaction files // 2. Write (conditional put) the manifest - assert_eq!(writes, 2, "i = {}", i); + assert_io_eq!(io_stats, write_iops, 2, "write txn + manifest, i = {}", i); } // Commit transaction with URI and session @@ -614,9 +600,9 @@ mod tests { // Session should still be re-used // However, the dataset needs to be loaded and the read version checked out, // so an additional 4 IOPs are needed. - let (reads, writes) = get_new_iops(); - assert_eq!(reads, 5); - assert_eq!(writes, 2); + let io_stats = io_tracker.incremental_stats(); + assert_io_eq!(io_stats, read_iops, 5, "load dataset + check version"); + assert_io_eq!(io_stats, write_iops, 2, "write txn + manifest"); // Commit transaction with URI and new session. Re-use the store // registry so we see the same store. @@ -629,9 +615,10 @@ mod tests { .unwrap(); assert_eq!(new_ds.manifest().version, 8); // Now we have to load all previous transactions. 
- let (reads, writes) = get_new_iops(); - assert!(reads > 10); - assert_eq!(writes, 2); + + let io_stats = io_tracker.incremental_stats(); + assert_io_gt!(io_stats, read_iops, 10); + assert_io_eq!(io_stats, write_iops, 2, "write txn + manifest"); } #[tokio::test] @@ -640,10 +627,10 @@ mod tests { // * write txn file (this could be optional one day) // * write manifest let session = Arc::new(Session::default()); - let io_tracker = Arc::new(StatsHolder::default()); + let io_tracker = IOTracker::default(); let write_params = WriteParams { store_params: Some(ObjectStoreParams { - object_store_wrapper: Some(io_tracker.clone()), + object_store_wrapper: Some(Arc::new(io_tracker.clone())), ..Default::default() }), session: Some(session.clone()), @@ -676,11 +663,11 @@ mod tests { // This could be zero, if we decided to be optimistic. However, that // would mean two wasted write requests (txn + manifest) if there was // a conflict. We choose to be pessimistic for more consistent performance. - assert_eq!(io_stats.read_iops, 1); - assert_eq!(io_stats.write_iops, 2); + assert_io_eq!(io_stats, read_iops, 1); + assert_io_eq!(io_stats, write_iops, 2); // We can't write them in parallel. The transaction file must exist before // we can write the manifest. - assert_eq!(io_stats.num_hops, 3); + assert_io_eq!(io_stats, num_hops, 3); } #[tokio::test] @@ -688,7 +675,7 @@ mod tests { async fn test_commit_conflict_iops(#[values(true, false)] use_cache: bool) { let cache_size = if use_cache { 10_000 } else { 0 }; let session = Arc::new(Session::new(0, cache_size, Default::default())); - let io_tracker = Arc::new(StatsHolder::default()); + let io_tracker = Arc::new(IOTracker::default()); // We need throttled to correctly count num hops. Otherwise, memory store // returns synchronously, and each request is 1 hop. let throttled = Arc::new(ThrottledStoreWrapper { @@ -752,17 +739,19 @@ mod tests { // For total of 3 + 2 * num_other_txns io requests. 
If we have caching enabled, we can skip 2 * num_other_txns // of those. We should be able to read in 5 hops. if use_cache { - assert_eq!(io_stats.read_iops, 1); // Just list versions - assert_eq!(io_stats.num_hops, 3); + assert_io_eq!(io_stats, read_iops, 1); // Just list versions + assert_io_eq!(io_stats, num_hops, 3); } else { // We need to read the other manifests and transactions. - assert_eq!(io_stats.read_iops, 1 + num_other_txns * 2); + + use lance_io::assert_io_lt; + assert_io_eq!(io_stats, read_iops, 1 + num_other_txns * 2); // It's possible to read the txns for some versions before we // finish reading later versions and so the entire "read versions // and txs" may appear as 1 hop instead of 2. - assert!(io_stats.num_hops <= 5); + assert_io_lt!(io_stats, num_hops, 6); } - assert_eq!(io_stats.write_iops, 2); // txn + manifest + assert_io_eq!(io_stats, write_iops, 2); // txn + manifest } #[tokio::test] diff --git a/rust/lance/src/index.rs b/rust/lance/src/index.rs index 09522a4aa2a..b6e5df3b2d7 100644 --- a/rust/lance/src/index.rs +++ b/rust/lance/src/index.rs @@ -1609,10 +1609,10 @@ mod tests { use crate::dataset::{ReadParams, WriteMode, WriteParams}; use crate::index::vector::VectorIndexParams; use crate::session::Session; - use crate::utils::test::{ - copy_test_data_to_tmp, DatagenExt, FragmentCount, FragmentRowCount, StatsHolder, - }; + use crate::utils::test::{copy_test_data_to_tmp, DatagenExt, FragmentCount, FragmentRowCount}; use arrow_array::Int32Array; + use lance_io::utils::tracking_store::IOTracker; + use lance_io::{assert_io_eq, assert_io_lt}; use super::*; @@ -2370,10 +2370,10 @@ mod tests { #[lance_test_macros::test(tokio::test)] async fn test_load_indices() { let session = Arc::new(Session::default()); - let io_stats = Arc::new(StatsHolder::default()); + let io_tracker = Arc::new(IOTracker::default()); let write_params = WriteParams { store_params: Some(ObjectStoreParams { - object_store_wrapper: Some(io_stats.clone()), + 
object_store_wrapper: Some(io_tracker.clone()), ..Default::default() }), session: Some(session.clone()), @@ -2404,17 +2404,13 @@ mod tests { ) .await .unwrap(); - io_stats.incremental_stats(); // Reset + io_tracker.incremental_stats(); // Reset let indices = dataset.load_indices().await.unwrap(); - let stats = io_stats.incremental_stats(); + let stats = io_tracker.incremental_stats(); // We should already have this cached since we just wrote it. - assert_eq!( - stats.read_iops, 0, - "Read IOPS should be 0. Saw requests: {:?}", - stats.requests - ); - assert_eq!(stats.read_bytes, 0); + assert_io_eq!(stats, read_iops, 0); + assert_io_eq!(stats, read_bytes, 0); assert_eq!(indices.len(), 1); session.index_cache.clear().await; // Clear the cache @@ -2423,7 +2419,7 @@ mod tests { .with_session(session.clone()) .with_read_params(ReadParams { store_options: Some(ObjectStoreParams { - object_store_wrapper: Some(io_stats.clone()), + object_store_wrapper: Some(io_tracker.clone()), ..Default::default() }), session: Some(session.clone()), @@ -2432,15 +2428,15 @@ mod tests { .load() .await .unwrap(); - let stats = io_stats.incremental_stats(); // Reset - assert!(stats.read_bytes < 64 * 1024); + let stats = io_tracker.incremental_stats(); // Reset + assert_io_lt!(stats, read_bytes, 64 * 1024); // Because the manifest is so small, we should have opportunistically // cached the indices in memory already. 
let indices2 = dataset2.load_indices().await.unwrap(); - let stats = io_stats.incremental_stats(); - assert_eq!(stats.read_iops, 0); - assert_eq!(stats.read_bytes, 0); + let stats = io_tracker.incremental_stats(); + assert_io_eq!(stats, read_iops, 0); + assert_io_eq!(stats, read_bytes, 0); assert_eq!(indices2.len(), 1); } diff --git a/rust/lance/src/io/commit/conflict_resolver.rs b/rust/lance/src/io/commit/conflict_resolver.rs index 8a2543f0c7b..1b9cc340099 100644 --- a/rust/lance/src/io/commit/conflict_resolver.rs +++ b/rust/lance/src/io/commit/conflict_resolver.rs @@ -1536,7 +1536,9 @@ mod tests { use arrow_schema::{DataType, Field, Schema}; use lance_core::Error; use lance_file::version::LanceFileVersion; + use lance_io::assert_io_eq; use lance_io::object_store::ObjectStoreParams; + use lance_io::utils::tracking_store::IOTracker; use lance_table::format::IndexMetadata; use lance_table::io::deletion::{deletion_file_path, read_deletion_file}; @@ -1546,14 +1548,13 @@ mod tests { use crate::{ dataset::{CommitBuilder, InsertBuilder, WriteParams}, io, - utils::test::StatsHolder, }; - async fn test_dataset(num_rows: usize, num_fragments: usize) -> (Dataset, Arc) { - let io_stats = Arc::new(StatsHolder::default()); + async fn test_dataset(num_rows: usize, num_fragments: usize) -> (Dataset, Arc) { + let io_tracker = Arc::new(IOTracker::default()); let write_params = WriteParams { store_params: Some(ObjectStoreParams { - object_store_wrapper: Some(io_stats.clone()), + object_store_wrapper: Some(io_tracker.clone()), ..Default::default() }), max_rows_per_file: num_rows / num_fragments, @@ -1577,7 +1578,7 @@ mod tests { .execute(vec![data]) .await .unwrap(); - (dataset, io_stats) + (dataset, io_tracker) } /// Helper function for tests to create UpdateConfig operations using old-style parameters @@ -1676,8 +1677,8 @@ mod tests { .check_txn(other_transaction, other_version as u64) .unwrap(); let io_stats = io_tracker.incremental_stats(); - assert_eq!(io_stats.read_iops, 0); - 
assert_eq!(io_stats.write_iops, 0); + assert_io_eq!(io_stats, read_iops, 0); + assert_io_eq!(io_stats, write_iops, 0); } let expected_transaction = Transaction { @@ -1690,8 +1691,8 @@ mod tests { assert_eq!(rebased_transaction, expected_transaction); // We didn't need to do any IO, so the stats should be 0. let io_stats = io_tracker.incremental_stats(); - assert_eq!(io_stats.read_iops, 0); - assert_eq!(io_stats.write_iops, 0); + assert_io_eq!(io_stats, read_iops, 0); + assert_io_eq!(io_stats, write_iops, 0); } async fn apply_deletion( @@ -1798,8 +1799,8 @@ mod tests { .check_txn(other_transaction, other_version as u64) .unwrap(); let io_stats = io_tracker.incremental_stats(); - assert_eq!(io_stats.read_iops, 0); - assert_eq!(io_stats.write_iops, 0); + assert_io_eq!(io_stats, read_iops, 0); + assert_io_eq!(io_stats, write_iops, 0); } // First iteration, we don't need to rewrite the deletion file. @@ -1811,8 +1812,8 @@ mod tests { let io_stats = io_tracker.incremental_stats(); if expected_rewrite { // Read the current deletion file, and write the new one. - assert_eq!(io_stats.read_iops, 0); // Cached - assert_eq!(io_stats.write_iops, 1); + assert_io_eq!(io_stats, read_iops, 0, "deletion file should be cached"); + assert_io_eq!(io_stats, write_iops, 1, "write one deletion file"); // TODO: The old deletion file should be gone. // This can be done later, as it will be cleaned up by the @@ -1853,11 +1854,11 @@ mod tests { ); assert!(dataset.object_store().exists(&new_path).await.unwrap()); - assert_eq!(io_stats.num_hops, 1); + assert_io_eq!(io_stats, num_hops, 1); } else { // No IO should have happened. 
- assert_eq!(io_stats.read_iops, 0); - assert_eq!(io_stats.write_iops, 0); + assert_io_eq!(io_stats, read_iops, 0); + assert_io_eq!(io_stats, write_iops, 0); } dataset = CommitBuilder::new(Arc::new(dataset)) @@ -1961,8 +1962,8 @@ mod tests { .unwrap(); let io_stats = io_tracker.incremental_stats(); - assert_eq!(io_stats.read_iops, 0); - assert_eq!(io_stats.write_iops, 0); + assert_io_eq!(io_stats, read_iops, 0); + assert_io_eq!(io_stats, write_iops, 0); let res = rebase.check_txn(&other_txn, 1); if other.ends_with("full") || ours.ends_with("full") { @@ -1986,8 +1987,8 @@ mod tests { ); let io_stats = io_tracker.incremental_stats(); - assert_eq!(io_stats.read_iops, 0); - assert_eq!(io_stats.write_iops, 0); + assert_io_eq!(io_stats, read_iops, 0); + assert_io_eq!(io_stats, write_iops, 0); let res = rebase.finish(&latest_dataset).await; assert!(matches!( @@ -1996,8 +1997,8 @@ mod tests { )); let io_stats = io_tracker.incremental_stats(); - assert_eq!(io_stats.read_iops, 0); // Cached deletion file - assert_eq!(io_stats.write_iops, 0); // Failed before writing + assert_io_eq!(io_stats, read_iops, 0, "deletion file should be cached"); + assert_io_eq!(io_stats, write_iops, 0, "failed before writing"); } #[derive(Clone, Copy, Debug, PartialEq, Eq)] diff --git a/rust/lance/src/io/commit/s3_test.rs b/rust/lance/src/io/commit/s3_test.rs index 8c534dec1e7..a6e848e0354 100644 --- a/rust/lance/src/io/commit/s3_test.rs +++ b/rust/lance/src/io/commit/s3_test.rs @@ -1,6 +1,6 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -use std::{ops::DerefMut, sync::Arc}; +use std::sync::Arc; use arrow::datatypes::Int32Type; @@ -14,6 +14,8 @@ use aws_config::{BehaviorVersion, ConfigLoader, Region, SdkConfig}; use aws_sdk_s3::{config::Credentials, Client as S3Client}; use futures::future::try_join_all; use lance_datagen::{array, gen_batch, RowCount}; +use lance_io::assert_io_eq; +use lance_io::utils::tracking_store::IOTracker; const CONFIG: 
&[(&str, &str)] = &[ ("access_key_id", "ACCESS_KEY"), @@ -175,16 +177,14 @@ impl Drop for DynamoDBCommitTable { #[tokio::test] async fn test_concurrent_writers() { - use crate::utils::test::IoTrackingStore; - let datagen = gen_batch().col("values", array::step::()); let data = datagen.into_batch_rows(RowCount::from(100)).unwrap(); - let (io_stats_wrapper, io_stats) = IoTrackingStore::new_wrapper(); + let io_tracker = Arc::new(IOTracker::default()); // Create a table let store_params = ObjectStoreParams { - object_store_wrapper: Some(io_stats_wrapper), + object_store_wrapper: Some(io_tracker.clone()), storage_options: Some( CONFIG .iter() @@ -207,11 +207,8 @@ async fn test_concurrent_writers() { .unwrap(); // 1 IOPS for uncommitted write - let incremental_stats = || { - let mut stats = io_stats.as_ref().lock().unwrap(); - std::mem::take(stats.deref_mut()) - }; - assert_eq!(incremental_stats().write_iops, 1); + let io_stats = io_tracker.incremental_stats(); + assert_io_eq!(io_stats, write_iops, 1); let dataset = CommitBuilder::new(&uri) .with_store_params(store_params.clone()) @@ -219,7 +216,8 @@ async fn test_concurrent_writers() { .await .unwrap(); // Commit: 2 IOPs. 
1 for transaction file, 1 for manifest file - assert_eq!(incremental_stats().write_iops, 2); + let io_stats = io_tracker.incremental_stats(); + assert_io_eq!(io_stats, write_iops, 2); let dataset = Arc::new(dataset); let old_version = dataset.manifest().version; @@ -259,8 +257,6 @@ async fn test_concurrent_writers() { #[tokio::test] async fn test_ddb_open_iops() { - use crate::utils::test::IoTrackingStore; - let bucket = S3Bucket::new("test-ddb-iops").await; let ddb_table = DynamoDBCommitTable::new("test-ddb-iops").await; let uri = format!("s3+ddb://{}/test?ddbTableName={}", bucket.0, ddb_table.0); @@ -268,11 +264,11 @@ async fn test_ddb_open_iops() { let datagen = gen_batch().col("values", array::step::()); let data = datagen.into_batch_rows(RowCount::from(100)).unwrap(); - let (io_stats_wrapper, io_stats) = IoTrackingStore::new_wrapper(); + let io_tracker = Arc::new(IOTracker::default()); // Create a table let store_params = ObjectStoreParams { - object_store_wrapper: Some(io_stats_wrapper), + object_store_wrapper: Some(io_tracker.clone()), storage_options: Some( CONFIG .iter() @@ -293,11 +289,8 @@ async fn test_ddb_open_iops() { .unwrap(); // 1 IOPS for uncommitted write - let incremental_stats = || { - let mut stats = io_stats.as_ref().lock().unwrap(); - std::mem::take(stats.deref_mut()) - }; - assert_eq!(incremental_stats().write_iops, 1); + let io_stats = io_tracker.incremental_stats(); + assert_io_eq!(io_stats, write_iops, 1); let _ = CommitBuilder::new(&uri) .with_store_params(store_params.clone()) @@ -310,10 +303,9 @@ async fn test_ddb_open_iops() { // * write staged file // * copy to final file // * delete staged file - let stats = incremental_stats(); - - assert_eq!(stats.write_iops, 4); - assert_eq!(stats.read_iops, 1); + let io_stats = io_tracker.incremental_stats(); + assert_io_eq!(io_stats, write_iops, 4); + assert_io_eq!(io_stats, read_iops, 1); let dataset = DatasetBuilder::from_uri(&uri) .with_read_params(ReadParams { @@ -323,11 +315,11 @@ async 
fn test_ddb_open_iops() { .load() .await .unwrap(); - let stats = incremental_stats(); + let io_stats = io_tracker.incremental_stats(); // Open dataset can be read with 1 IOP, just to read the manifest. // Looking up latest manifest is handled in dynamodb. - assert_eq!(stats.read_iops, 1); - assert_eq!(stats.write_iops, 0); + assert_io_eq!(io_stats, read_iops, 1); + assert_io_eq!(io_stats, write_iops, 0); // Append let dataset = InsertBuilder::new(Arc::new(dataset)) @@ -338,17 +330,17 @@ async fn test_ddb_open_iops() { .execute(vec![data.clone()]) .await .unwrap(); - let stats = incremental_stats(); + let io_stats = io_tracker.incremental_stats(); // Append: 5 IOPS: data file, transaction file, 3x manifest file - assert_eq!(stats.write_iops, 5); + assert_io_eq!(io_stats, write_iops, 5); // TODO: we can reduce this by implementing a specialized CommitHandler::list_manifest_locations() // for the DDB commit handler. - assert_eq!(stats.read_iops, 1); + assert_io_eq!(io_stats, read_iops, 1); // Checkout original version dataset.checkout_version(1).await.unwrap(); - let stats = incremental_stats(); + let io_stats = io_tracker.incremental_stats(); // Checkout: 1 IOPS: manifest file - assert_eq!(stats.read_iops, 1); - assert_eq!(stats.write_iops, 0); + assert_io_eq!(io_stats, read_iops, 1); + assert_io_eq!(io_stats, write_iops, 0); } diff --git a/rust/lance/src/utils/test.rs b/rust/lance/src/utils/test.rs index 0e591500c70..49fdb14abb7 100644 --- a/rust/lance/src/utils/test.rs +++ b/rust/lance/src/utils/test.rs @@ -1,30 +1,19 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -use std::fmt::{Display, Formatter}; -use std::ops::Range; -use std::sync::atomic::AtomicU16; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use lance_core::utils::tempfile::{TempDir, TempStrDir}; use snafu::location; use arrow_array::{RecordBatch, RecordBatchIterator}; use arrow_schema::Schema as ArrowSchema; -use bytes::Bytes; use 
datafusion_physical_plan::ExecutionPlan; -use futures::stream::BoxStream; use lance_arrow::RecordBatchExt; use lance_core::datatypes::Schema; use lance_datagen::{BatchCount, BatchGeneratorBuilder, ByteCount, RowCount}; use lance_file::version::LanceFileVersion; -use lance_io::object_store::WrappingObjectStore; use lance_table::format::Fragment; -use object_store::path::Path; -use object_store::{ - GetOptions, GetRange, GetResult, ListResult, MultipartUpload, ObjectMeta, ObjectStore, - PutMultipartOptions, PutOptions, PutPayload, PutResult, Result as OSResult, UploadPart, -}; use rand::prelude::SliceRandom; use rand::{Rng, SeedableRng}; @@ -276,331 +265,6 @@ fn field_structure(fragment: &Fragment) -> Vec> { .collect::>() } -#[derive(Debug, Default)] -pub struct IoStats { - pub read_iops: u64, - pub read_bytes: u64, - pub write_iops: u64, - pub write_bytes: u64, - /// Number of disjoint periods where at least one IO is in-flight. - pub num_hops: u64, - pub requests: Vec, -} - -// These fields are "dead code" because we just use them right now to display -// in test failure messages through Debug. (The lint ignores Debug impls.) 
-#[allow(dead_code)] -#[derive(Debug, Clone)] -pub struct IoRequestRecord { - pub method: &'static str, - pub path: Path, - pub range: Option>, -} - -impl Display for IoStats { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{:#?}", self) - } -} - -#[derive(Debug)] -pub struct IoTrackingStore { - target: Arc, - stats: Arc>, - active_requests: Arc, -} - -impl Display for IoTrackingStore { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "{:#?}", self) - } -} - -#[derive(Debug, Default, Clone)] -pub struct StatsHolder(Arc>); - -impl StatsHolder { - pub fn incremental_stats(&self) -> IoStats { - std::mem::take(&mut *self.0.lock().unwrap()) - } -} - -impl WrappingObjectStore for StatsHolder { - fn wrap( - &self, - target: Arc, - _storage_options: Option<&std::collections::HashMap>, - ) -> Arc { - Arc::new(IoTrackingStore { - target, - stats: self.0.clone(), - active_requests: Arc::new(AtomicU16::new(0)), - }) - } -} - -impl IoTrackingStore { - pub fn new_wrapper() -> (Arc, Arc>) { - let stats = Arc::new(Mutex::new(IoStats::default())); - (Arc::new(StatsHolder(stats.clone())), stats) - } - - fn record_read( - &self, - method: &'static str, - path: Path, - num_bytes: u64, - range: Option>, - ) { - let mut stats = self.stats.lock().unwrap(); - stats.read_iops += 1; - stats.read_bytes += num_bytes; - stats.requests.push(IoRequestRecord { - method, - path, - range, - }); - } - - fn record_write(&self, num_bytes: u64) { - let mut stats = self.stats.lock().unwrap(); - stats.write_iops += 1; - stats.write_bytes += num_bytes; - } - - fn hop_guard(&self) -> HopGuard { - HopGuard::new(self.active_requests.clone(), self.stats.clone()) - } -} - -#[async_trait::async_trait] -#[deny(clippy::missing_trait_methods)] -impl ObjectStore for IoTrackingStore { - async fn put(&self, location: &Path, bytes: PutPayload) -> OSResult { - let _guard = self.hop_guard(); - self.record_write(bytes.content_length() as u64); - self.target.put(location, 
bytes).await - } - - async fn put_opts( - &self, - location: &Path, - bytes: PutPayload, - opts: PutOptions, - ) -> OSResult { - let _guard = self.hop_guard(); - self.record_write(bytes.content_length() as u64); - self.target.put_opts(location, bytes, opts).await - } - - async fn put_multipart(&self, location: &Path) -> OSResult> { - let _guard = self.hop_guard(); - let target = self.target.put_multipart(location).await?; - Ok(Box::new(IoTrackingMultipartUpload { - target, - stats: self.stats.clone(), - _guard, - })) - } - - async fn put_multipart_opts( - &self, - location: &Path, - opts: PutMultipartOptions, - ) -> OSResult> { - let _guard = self.hop_guard(); - let target = self.target.put_multipart_opts(location, opts).await?; - Ok(Box::new(IoTrackingMultipartUpload { - target, - stats: self.stats.clone(), - _guard, - })) - } - - async fn get(&self, location: &Path) -> OSResult { - let _guard = self.hop_guard(); - let result = self.target.get(location).await; - if let Ok(result) = &result { - let num_bytes = result.range.end - result.range.start; - self.record_read("get", location.to_owned(), num_bytes, None); - } - result - } - - async fn get_opts(&self, location: &Path, options: GetOptions) -> OSResult { - let _guard = self.hop_guard(); - let range = match &options.range { - Some(GetRange::Bounded(range)) => Some(range.clone()), - _ => None, // TODO: fill in other options. 
- }; - let result = self.target.get_opts(location, options).await; - if let Ok(result) = &result { - let num_bytes = result.range.end - result.range.start; - - self.record_read("get_opts", location.to_owned(), num_bytes, range); - } - result - } - - async fn get_range(&self, location: &Path, range: Range) -> OSResult { - let _guard = self.hop_guard(); - let result = self.target.get_range(location, range.clone()).await; - if let Ok(result) = &result { - self.record_read( - "get_range", - location.to_owned(), - result.len() as u64, - Some(range), - ); - } - result - } - - async fn get_ranges(&self, location: &Path, ranges: &[Range]) -> OSResult> { - let _guard = self.hop_guard(); - let result = self.target.get_ranges(location, ranges).await; - if let Ok(result) = &result { - self.record_read( - "get_ranges", - location.to_owned(), - result.iter().map(|b| b.len() as u64).sum(), - None, - ); - } - result - } - - async fn head(&self, location: &Path) -> OSResult { - let _guard = self.hop_guard(); - self.record_read("head", location.to_owned(), 0, None); - self.target.head(location).await - } - - async fn delete(&self, location: &Path) -> OSResult<()> { - let _guard = self.hop_guard(); - self.record_write(0); - self.target.delete(location).await - } - - fn delete_stream<'a>( - &'a self, - locations: BoxStream<'a, OSResult>, - ) -> BoxStream<'a, OSResult> { - self.target.delete_stream(locations) - } - - fn list(&self, prefix: Option<&Path>) -> BoxStream<'static, OSResult> { - let _guard = self.hop_guard(); - self.record_read("list", prefix.cloned().unwrap_or_default(), 0, None); - self.target.list(prefix) - } - - fn list_with_offset( - &self, - prefix: Option<&Path>, - offset: &Path, - ) -> BoxStream<'static, OSResult> { - self.record_read( - "list_with_offset", - prefix.cloned().unwrap_or_default(), - 0, - None, - ); - self.target.list_with_offset(prefix, offset) - } - - async fn list_with_delimiter(&self, prefix: Option<&Path>) -> OSResult { - let _guard = 
self.hop_guard(); - self.record_read( - "list_with_delimiter", - prefix.cloned().unwrap_or_default(), - 0, - None, - ); - self.target.list_with_delimiter(prefix).await - } - - async fn copy(&self, from: &Path, to: &Path) -> OSResult<()> { - let _guard = self.hop_guard(); - self.record_write(0); - self.target.copy(from, to).await - } - - async fn rename(&self, from: &Path, to: &Path) -> OSResult<()> { - let _guard = self.hop_guard(); - self.record_write(0); - self.target.rename(from, to).await - } - - async fn rename_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> { - let _guard = self.hop_guard(); - self.record_write(0); - self.target.rename_if_not_exists(from, to).await - } - - async fn copy_if_not_exists(&self, from: &Path, to: &Path) -> OSResult<()> { - let _guard = self.hop_guard(); - self.record_write(0); - self.target.copy_if_not_exists(from, to).await - } -} - -#[derive(Debug)] -struct IoTrackingMultipartUpload { - target: Box, - stats: Arc>, - _guard: HopGuard, -} - -#[async_trait::async_trait] -impl MultipartUpload for IoTrackingMultipartUpload { - async fn abort(&mut self) -> OSResult<()> { - self.target.abort().await - } - - async fn complete(&mut self) -> OSResult { - self.target.complete().await - } - - fn put_part(&mut self, payload: PutPayload) -> UploadPart { - { - let mut stats = self.stats.lock().unwrap(); - stats.write_iops += 1; - stats.write_bytes += payload.content_length() as u64; - } - self.target.put_part(payload) - } -} - -#[derive(Debug)] -struct HopGuard { - active_requests: Arc, - stats: Arc>, -} - -impl HopGuard { - fn new(active_requests: Arc, stats: Arc>) -> Self { - active_requests.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - Self { - active_requests, - stats, - } - } -} - -impl Drop for HopGuard { - fn drop(&mut self) { - if self - .active_requests - .fetch_sub(1, std::sync::atomic::Ordering::SeqCst) - == 1 - { - let mut stats = self.stats.lock().unwrap(); - stats.num_hops += 1; - } - } -} - pub struct 
FragmentCount(pub u32); impl From for FragmentCount {