From d0ceed23fa2fcea6f43fa6ba88b5ca116f6fc9fc Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Wed, 7 Apr 2021 08:40:55 -0400
Subject: [PATCH 1/2] Refactor test utilities into test module

---
 rust/datafusion/src/execution/context.rs | 140 ++++++++++++++---------
 rust/datafusion/src/test/mod.rs          |  28 +++++
 2 files changed, 116 insertions(+), 52 deletions(-)

diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index 4c419d983a6..d07994cc043 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -1752,7 +1752,7 @@ mod tests {
     async fn limit() -> Result<()> {
         let tmp_dir = TempDir::new()?;
         let mut ctx = create_ctx(&tmp_dir, 1)?;
-        ctx.register_table("t", table_with_sequence(1, 1000).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1000).unwrap())
             .unwrap();
 
         let results =
@@ -1788,30 +1788,77 @@ mod tests {
         Ok(())
     }
 
-    /// Return a RecordBatch with a single Int32 array with values (0..sz)
-    fn make_partition(sz: i32) -> RecordBatch {
-        let seq_start = 0;
-        let seq_end = sz;
-        let values = (seq_start..seq_end).collect::<Vec<_>>();
-        let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
-        let arr = Arc::new(Int32Array::from(values));
-        let arr = arr as ArrayRef;
+    #[tokio::test]
+    async fn limit_multi_partitions() -> Result<()> {
+        let tmp_dir = TempDir::new()?;
+        let mut ctx = create_ctx(&tmp_dir, 1)?;
+
+        let partitions = vec![
+            vec![test::make_partition(0)],
+            vec![test::make_partition(1)],
+            vec![test::make_partition(2)],
+            vec![test::make_partition(3)],
+            vec![test::make_partition(4)],
+            vec![test::make_partition(5)],
+        ];
+        let schema = partitions[0][0].schema();
+        let provider = Arc::new(MemTable::try_new(schema, partitions).unwrap());
 
-        RecordBatch::try_new(schema, vec![arr]).unwrap()
+        ctx.register_table("t", provider).unwrap();
+
+        // select all rows
+        let results = plan_and_collect(&mut ctx, "SELECT i FROM t").await.unwrap();
+
+        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
+        assert_eq!(num_rows, 15);
+
+        for limit in 1..10 {
+            let query = format!("SELECT i FROM t limit {}", limit);
+            let results = plan_and_collect(&mut ctx, &query).await.unwrap();
+
+            let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
+            assert_eq!(num_rows, limit, "mismatch with query {}", query);
+        }
+
+        Ok(())
     }
 
     #[tokio::test]
-    async fn limit_multi_partitions() -> Result<()> {
+    async fn limit_multi_batch() -> Result<()> {
         let tmp_dir = TempDir::new()?;
         let mut ctx = create_ctx(&tmp_dir, 1)?;
 
         let partitions = vec![
-            vec![make_partition(0)],
-            vec![make_partition(1)],
-            vec![make_partition(2)],
-            vec![make_partition(3)],
-            vec![make_partition(4)],
-            vec![make_partition(5)],
+            vec![
+                test::make_partition(0),
+                test::make_partition(1),
+                test::make_partition(2),
+            ],
+            vec![
+                test::make_partition(1),
+                test::make_partition(2),
+                test::make_partition(3),
+            ],
+            vec![
+                test::make_partition(2),
+                test::make_partition(3),
+                test::make_partition(4),
+            ],
+            vec![
+                test::make_partition(3),
+                test::make_partition(4),
+                test::make_partition(5),
+            ],
+            vec![
+                test::make_partition(4),
+                test::make_partition(5),
+                test::make_partition(6),
+            ],
+            vec![
+                test::make_partition(5),
+                test::make_partition(6),
+                test::make_partition(7),
+            ],
         ];
         let schema = partitions[0][0].schema();
         let provider = Arc::new(MemTable::try_new(schema, partitions).unwrap());
@@ -1822,7 +1869,7 @@ mod tests {
         let results = plan_and_collect(&mut ctx, "SELECT i FROM t").await.unwrap();
 
         let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
-        assert_eq!(num_rows, 15);
+        assert_eq!(num_rows, 63);
 
         for limit in 1..10 {
             let query = format!("SELECT i FROM t limit {}", limit);
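For reference, both row-count assertions fall directly out of make_partition's contract (a batch with values 0..sz has sz rows): the six single-batch partitions in limit_multi_partitions hold 0+1+2+3+4+5 = 15 rows, while in limit_multi_batch partition i holds three batches of i, i+1, and i+2 rows, for a total of 63. A standalone sketch of the arithmetic (plain Rust, independent of DataFusion):

    fn main() {
        // limit_multi_partitions: partition i is one batch with rows 0..i, i in 0..=5
        let single_batch: usize = (0..=5).sum();
        assert_eq!(single_batch, 15);

        // limit_multi_batch: partition i holds batches of i, i + 1 and i + 2 rows,
        // i.e. 3i + 3 rows per partition
        let multi_batch: usize = (0..=5).map(|i| i + (i + 1) + (i + 2)).sum();
        assert_eq!(multi_batch, 63);
    }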
t").await.unwrap(); let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum(); - assert_eq!(num_rows, 15); + assert_eq!(num_rows, 63); for limit in 1..10 { let query = format!("SELECT i FROM t limit {}", limit); @@ -1838,7 +1885,7 @@ mod tests { #[tokio::test] async fn case_sensitive_identifiers_functions() { let mut ctx = ExecutionContext::new(); - ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) .unwrap(); let expected = vec![ @@ -1878,7 +1925,7 @@ mod tests { #[tokio::test] async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { let mut ctx = ExecutionContext::new(); - ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) .unwrap(); let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0])); @@ -1918,7 +1965,7 @@ mod tests { #[tokio::test] async fn case_sensitive_identifiers_aggregates() { let mut ctx = ExecutionContext::new(); - ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) .unwrap(); let expected = vec![ @@ -1958,7 +2005,7 @@ mod tests { #[tokio::test] async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { let mut ctx = ExecutionContext::new(); - ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) .unwrap(); // Note capitalizaton @@ -2356,19 +2403,6 @@ mod tests { Ok(()) } - fn table_with_sequence( - seq_start: i32, - seq_end: i32, - ) -> Result> { - let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)])); - let arr = Arc::new(Int32Array::from((seq_start..=seq_end).collect::>())); - let partitions = vec![vec![RecordBatch::try_new( - schema.clone(), - vec![arr as ArrayRef], - )?]]; - Ok(Arc::new(MemTable::try_new(schema, partitions)?)) - } - #[tokio::test] async fn information_schema_tables_not_exist_by_default() { let mut ctx = ExecutionContext::new(); @@ -2411,7 +2445,7 @@ mod tests { ); // Now, register an empty table - ctx.register_table("t", table_with_sequence(1, 1).unwrap()) + ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) .unwrap(); let result = @@ -2431,7 +2465,7 @@ mod tests { assert_batches_sorted_eq!(expected, &result); // Newly added tables should appear - ctx.register_table("t2", table_with_sequence(1, 1).unwrap()) + ctx.register_table("t2", test::table_with_sequence(1, 1).unwrap()) .unwrap(); let result = @@ -2460,10 +2494,10 @@ mod tests { let catalog = MemoryCatalogProvider::new(); let schema = MemorySchemaProvider::new(); schema - .register_table("t1".to_owned(), table_with_sequence(1, 1).unwrap()) + .register_table("t1".to_owned(), test::table_with_sequence(1, 1).unwrap()) .unwrap(); schema - .register_table("t2".to_owned(), table_with_sequence(1, 1).unwrap()) + .register_table("t2".to_owned(), test::table_with_sequence(1, 1).unwrap()) .unwrap(); catalog.register_schema("my_schema", Arc::new(schema)); ctx.register_catalog("my_catalog", Arc::new(catalog)); @@ -2471,7 +2505,7 @@ mod tests { let catalog = MemoryCatalogProvider::new(); let schema = MemorySchemaProvider::new(); schema - .register_table("t3".to_owned(), table_with_sequence(1, 1).unwrap()) + .register_table("t3".to_owned(), test::table_with_sequence(1, 1).unwrap()) .unwrap(); catalog.register_schema("my_other_schema", Arc::new(schema)); ctx.register_catalog("my_other_catalog", Arc::new(catalog)); @@ 
@@ -2356,19 +2403,6 @@ mod tests {
         Ok(())
     }
 
-    fn table_with_sequence(
-        seq_start: i32,
-        seq_end: i32,
-    ) -> Result<Arc<dyn TableProvider + Send + Sync>> {
-        let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
-        let arr = Arc::new(Int32Array::from((seq_start..=seq_end).collect::<Vec<_>>()));
-        let partitions = vec![vec![RecordBatch::try_new(
-            schema.clone(),
-            vec![arr as ArrayRef],
-        )?]];
-        Ok(Arc::new(MemTable::try_new(schema, partitions)?))
-    }
-
     #[tokio::test]
     async fn information_schema_tables_not_exist_by_default() {
         let mut ctx = ExecutionContext::new();
@@ -2411,7 +2445,7 @@ mod tests {
         );
 
         // Now, register an empty table
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         let result =
@@ -2431,7 +2465,7 @@ mod tests {
         assert_batches_sorted_eq!(expected, &result);
 
         // Newly added tables should appear
-        ctx.register_table("t2", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t2", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         let result =
@@ -2460,10 +2494,10 @@ mod tests {
         let catalog = MemoryCatalogProvider::new();
         let schema = MemorySchemaProvider::new();
         schema
-            .register_table("t1".to_owned(), table_with_sequence(1, 1).unwrap())
+            .register_table("t1".to_owned(), test::table_with_sequence(1, 1).unwrap())
             .unwrap();
         schema
-            .register_table("t2".to_owned(), table_with_sequence(1, 1).unwrap())
+            .register_table("t2".to_owned(), test::table_with_sequence(1, 1).unwrap())
             .unwrap();
         catalog.register_schema("my_schema", Arc::new(schema));
         ctx.register_catalog("my_catalog", Arc::new(catalog));
@@ -2471,7 +2505,7 @@ mod tests {
         let catalog = MemoryCatalogProvider::new();
         let schema = MemorySchemaProvider::new();
         schema
-            .register_table("t3".to_owned(), table_with_sequence(1, 1).unwrap())
+            .register_table("t3".to_owned(), test::table_with_sequence(1, 1).unwrap())
             .unwrap();
         catalog.register_schema("my_other_schema", Arc::new(schema));
         ctx.register_catalog("my_other_catalog", Arc::new(catalog));
@@ -2503,7 +2537,7 @@ mod tests {
     async fn information_schema_show_tables_no_information_schema() {
         let mut ctx = ExecutionContext::with_config(ExecutionConfig::new());
 
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         // use show tables alias
@@ -2518,7 +2552,7 @@ mod tests {
             ExecutionConfig::new().with_information_schema(true),
         );
 
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         // use show tables alias
@@ -2544,7 +2578,7 @@ mod tests {
     async fn information_schema_show_columns_no_information_schema() {
         let mut ctx = ExecutionContext::with_config(ExecutionConfig::new());
 
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t")
@@ -2558,7 +2592,7 @@ mod tests {
     async fn information_schema_show_columns_like_where() {
         let mut ctx = ExecutionContext::with_config(ExecutionConfig::new());
 
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         let expected =
@@ -2582,7 +2616,7 @@ mod tests {
             ExecutionConfig::new().with_information_schema(true),
         );
 
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
            .unwrap();
 
         let result = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t")
@@ -2620,7 +2654,7 @@ mod tests {
             ExecutionConfig::new().with_information_schema(true),
        );
 
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         let result = plan_and_collect(&mut ctx, "SHOW FULL COLUMNS FROM t")
@@ -2649,7 +2683,7 @@ mod tests {
             ExecutionConfig::new().with_information_schema(true),
         );
 
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         let result = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM public.t")
@@ -2752,7 +2786,7 @@ mod tests {
 
         let schema = MemorySchemaProvider::new();
         schema
-            .register_table("t1".to_owned(), table_with_sequence(1, 1).unwrap())
+            .register_table("t1".to_owned(), test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         schema
@@ -2790,7 +2824,7 @@ mod tests {
         );
 
         assert!(matches!(
-            ctx.register_table("test", table_with_sequence(1, 1)?),
+            ctx.register_table("test", test::table_with_sequence(1, 1)?),
             Err(DataFusionError::Plan(_))
         ));
 
@@ -2812,7 +2846,7 @@ mod tests {
         let catalog = MemoryCatalogProvider::new();
         let schema = MemorySchemaProvider::new();
 
-        schema.register_table("test".to_owned(), table_with_sequence(1, 1)?)?;
+        schema.register_table("test".to_owned(), test::table_with_sequence(1, 1)?)?;
 
         catalog.register_schema("my_schema", Arc::new(schema));
         ctx.register_catalog("my_catalog", Arc::new(catalog));
@@ -2842,13 +2876,15 @@ mod tests {
         let catalog_a = MemoryCatalogProvider::new();
         let schema_a = MemorySchemaProvider::new();
-        schema_a.register_table("table_a".to_owned(), table_with_sequence(1, 1)?)?;
+        schema_a
+            .register_table("table_a".to_owned(), test::table_with_sequence(1, 1)?)?;
         catalog_a.register_schema("schema_a", Arc::new(schema_a));
         ctx.register_catalog("catalog_a", Arc::new(catalog_a));
 
         let catalog_b = MemoryCatalogProvider::new();
         let schema_b = MemorySchemaProvider::new();
-        schema_b.register_table("table_b".to_owned(), table_with_sequence(1, 2)?)?;
+        schema_b
+            .register_table("table_b".to_owned(), test::table_with_sequence(1, 2)?)?;
         catalog_b.register_schema("schema_b", Arc::new(schema_b));
         ctx.register_catalog("catalog_b", Arc::new(catalog_b));
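The two-catalog setup at the end of this hunk feeds a cross-catalog test whose assertions fall outside the diff. What it enables is addressing both tables from one ExecutionContext through fully qualified catalog.schema.table names, roughly like this (a sketch; the actual test queries are not shown in the hunk):

    // table_a lives in catalog_a.schema_a, table_b in catalog_b.schema_b
    let result_a =
        plan_and_collect(&mut ctx, "SELECT i FROM catalog_a.schema_a.table_a").await?;
    let result_b =
        plan_and_collect(&mut ctx, "SELECT i FROM catalog_b.schema_b.table_b").await?;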
diff --git a/rust/datafusion/src/test/mod.rs b/rust/datafusion/src/test/mod.rs
index 04f340a9936..4fdf3346c9a 100644
--- a/rust/datafusion/src/test/mod.rs
+++ b/rust/datafusion/src/test/mod.rs
@@ -20,6 +20,7 @@
 use crate::datasource::{MemTable, TableProvider};
 use crate::error::Result;
 use crate::logical_plan::{LogicalPlan, LogicalPlanBuilder};
+use array::ArrayRef;
 use arrow::array::{self, Int32Array};
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
@@ -154,6 +155,33 @@ pub fn columns(schema: &Schema) -> Vec<String> {
     schema.fields().iter().map(|f| f.name().clone()).collect()
 }
 
+/// Return a new table provider that has a single Int32 column with
+/// values between `seq_start` and `seq_end`
+pub fn table_with_sequence(
+    seq_start: i32,
+    seq_end: i32,
+) -> Result<Arc<dyn TableProvider + Send + Sync>> {
+    let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
+    let arr = Arc::new(Int32Array::from((seq_start..=seq_end).collect::<Vec<_>>()));
+    let partitions = vec![vec![RecordBatch::try_new(
+        schema.clone(),
+        vec![arr as ArrayRef],
+    )?]];
+    Ok(Arc::new(MemTable::try_new(schema, partitions)?))
+}
+
+/// Return a RecordBatch with a single Int32 array with values (0..sz)
+pub fn make_partition(sz: i32) -> RecordBatch {
+    let seq_start = 0;
+    let seq_end = sz;
+    let values = (seq_start..seq_end).collect::<Vec<_>>();
+    let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
+    let arr = Arc::new(Int32Array::from(values));
+    let arr = arr as ArrayRef;
+
+    RecordBatch::try_new(schema, vec![arr]).unwrap()
+}
+
 pub mod user_defined;
 pub mod variable;
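With the two helpers now public in crate::test, any test in the crate can build a small in-memory table or batch in one line. A sketch of typical use inside a #[tokio::test], following the patterns in context.rs above (the ctx setup is assumed):

    // a one-partition table with the values 1..=100 in column "i"
    ctx.register_table("numbers", test::table_with_sequence(1, 100)?)
        .unwrap();

    // a single RecordBatch with the values 0..8, handy as a building
    // block for multi-batch or multi-partition MemTables
    let batch = test::make_partition(8);
    assert_eq!(batch.num_rows(), 8);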
From b923fdd47b115654837aa53da58fcc50b13cb7af Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Wed, 7 Apr 2021 09:24:31 -0400
Subject: [PATCH 2/2] ARROW-12254: [Rust][DataFusion] Stop polling limit input
 once limit is reached

---
 rust/datafusion/src/execution/context.rs   |  59 ------------
 rust/datafusion/src/physical_plan/limit.rs |  61 ++++++++++--
 rust/datafusion/src/test/exec.rs           | 102 +++++++++++++++++++++
 rust/datafusion/src/test/mod.rs            |   1 +
 4 files changed, 156 insertions(+), 67 deletions(-)
 create mode 100644 rust/datafusion/src/test/exec.rs

diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index d07994cc043..a8199665ecf 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -1823,65 +1823,6 @@ mod tests {
         Ok(())
     }
 
-    #[tokio::test]
-    async fn limit_multi_batch() -> Result<()> {
-        let tmp_dir = TempDir::new()?;
-        let mut ctx = create_ctx(&tmp_dir, 1)?;
-
-        let partitions = vec![
-            vec![
-                test::make_partition(0),
-                test::make_partition(1),
-                test::make_partition(2),
-            ],
-            vec![
-                test::make_partition(1),
-                test::make_partition(2),
-                test::make_partition(3),
-            ],
-            vec![
-                test::make_partition(2),
-                test::make_partition(3),
-                test::make_partition(4),
-            ],
-            vec![
-                test::make_partition(3),
-                test::make_partition(4),
-                test::make_partition(5),
-            ],
-            vec![
-                test::make_partition(4),
-                test::make_partition(5),
-                test::make_partition(6),
-            ],
-            vec![
-                test::make_partition(5),
-                test::make_partition(6),
-                test::make_partition(7),
-            ],
-        ];
-        let schema = partitions[0][0].schema();
-        let provider = Arc::new(MemTable::try_new(schema, partitions).unwrap());
-
-        ctx.register_table("t", provider).unwrap();
-
-        // select all rows
-        let results = plan_and_collect(&mut ctx, "SELECT i FROM t").await.unwrap();
-
-        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
-        assert_eq!(num_rows, 63);
-
-        for limit in 1..10 {
-            let query = format!("SELECT i FROM t limit {}", limit);
-            let results = plan_and_collect(&mut ctx, &query).await.unwrap();
-
-            let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
-            assert_eq!(num_rows, limit, "mismatch with query {}", query);
-        }
-
-        Ok(())
-    }
-
     #[tokio::test]
     async fn case_sensitive_identifiers_functions() {
         let mut ctx = ExecutionContext::new();
diff --git a/rust/datafusion/src/physical_plan/limit.rs b/rust/datafusion/src/physical_plan/limit.rs
index 2f0eb682dd8..c091196483f 100644
--- a/rust/datafusion/src/physical_plan/limit.rs
+++ b/rust/datafusion/src/physical_plan/limit.rs
@@ -200,23 +200,31 @@ pub fn truncate_batch(batch: &RecordBatch, n: usize) -> RecordBatch {
 
 /// A Limit stream limits the stream to up to `limit` rows.
 struct LimitStream {
+    /// The maximum number of rows to produce
     limit: usize,
-    input: SendableRecordBatchStream,
-    // the current count
+    /// The input to read from. This is set to None once the limit is
+    /// reached to enable early termination
+    input: Option<SendableRecordBatchStream>,
+    /// Copy of the input schema
+    schema: SchemaRef,
+    // the current number of rows which have been produced
     current_len: usize,
 }
 
 impl LimitStream {
     fn new(input: SendableRecordBatchStream, limit: usize) -> Self {
+        let schema = input.schema();
         Self {
             limit,
-            input,
+            input: Some(input),
+            schema,
             current_len: 0,
         }
     }
 
     fn stream_limit(&mut self, batch: RecordBatch) -> Option<RecordBatch> {
         if self.current_len == self.limit {
+            self.input = None; // clear input so it can be dropped early
             None
         } else if self.current_len + batch.num_rows() <= self.limit {
             self.current_len += batch.num_rows();
@@ -224,6 +232,7 @@ impl LimitStream {
         } else {
             let batch_rows = self.limit - self.current_len;
             self.current_len = self.limit;
+            self.input = None; // clear input so it can be dropped early
             Some(truncate_batch(&batch, batch_rows))
         }
     }
@@ -236,23 +245,29 @@ impl Stream for LimitStream {
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
-        self.input.poll_next_unpin(cx).map(|x| match x {
-            Some(Ok(batch)) => Ok(self.stream_limit(batch)).transpose(),
-            other => other,
-        })
+        match &mut self.input {
+            Some(input) => input.poll_next_unpin(cx).map(|x| match x {
+                Some(Ok(batch)) => Ok(self.stream_limit(batch)).transpose(),
+                other => other,
+            }),
+            // input has been cleared
+            None => Poll::Ready(None),
+        }
     }
 }
 
 impl RecordBatchStream for LimitStream {
     /// Get the schema
     fn schema(&self) -> SchemaRef {
-        self.input.schema()
+        self.schema.clone()
     }
 }
 
 #[cfg(test)]
 mod tests {
+    use common::collect;
+
     use super::*;
     use crate::physical_plan::common;
     use crate::physical_plan::csv::{CsvExec, CsvReadOptions};
@@ -290,4 +305,34 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn limit_early_shutdown() -> Result<()> {
+        let batches = vec![
+            test::make_partition(5),
+            test::make_partition(10),
+            test::make_partition(15),
+            test::make_partition(20),
+            test::make_partition(25),
+        ];
+        let input = test::exec::TestStream::new(batches);
+
+        let index = input.index();
+        assert_eq!(index.value(), 0);
+
+        // a limit of six needs to consume the entire first record batch
+        // (5 rows) and one row from the second
+        let limit_stream = LimitStream::new(Box::pin(input), 6);
+        assert_eq!(index.value(), 0);
+
+        let results = collect(Box::pin(limit_stream)).await.unwrap();
+        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
+        // Only 6 rows should have been produced
+        assert_eq!(num_rows, 6);
+
+        // Only the first two batches should be consumed
+        assert_eq!(index.value(), 2);
+
+        Ok(())
+    }
 }
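The heart of the fix: LimitStream now owns its input as an Option and sets it to None the moment the limit is satisfied, so the upstream operator chain is dropped (and its resources released) immediately rather than being held until the LimitStream itself goes away; afterwards poll_next answers Poll::Ready(None) without touching the input, which also makes the stream safe to poll after completion. The same pattern in a minimal, self-contained form (a sketch over plain streams, assuming futures and tokio as dependencies; no DataFusion types):

    use futures::{stream, Stream, StreamExt};
    use std::pin::Pin;
    use std::task::{Context, Poll};

    /// Yields at most `limit` items and drops the inner stream as soon as
    /// the limit is hit, rather than when the wrapper itself is dropped.
    struct EarlyDropTake<S> {
        input: Option<S>,
        limit: usize,
        seen: usize,
    }

    impl<S: Stream + Unpin> Stream for EarlyDropTake<S> {
        type Item = S::Item;

        fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<S::Item>> {
            let this = self.get_mut();
            if this.seen >= this.limit {
                this.input = None; // release upstream resources early
            }
            let input = match this.input.as_mut() {
                Some(input) => input,
                None => return Poll::Ready(None), // fused after early shutdown
            };
            let polled = input.poll_next_unpin(cx);
            if let Poll::Ready(Some(_)) = &polled {
                this.seen += 1;
                if this.seen == this.limit {
                    this.input = None; // drop the input the moment the limit is hit
                }
            }
            polled
        }
    }

    #[tokio::main]
    async fn main() {
        let limited = EarlyDropTake {
            input: Some(stream::iter(0..100)),
            limit: 6,
            seen: 0,
        };
        let collected: Vec<_> = limited.collect().await;
        assert_eq!(collected, vec![0, 1, 2, 3, 4, 5]);
    }

Unlike the sketch, the real LimitStream also has to truncate the batch that crosses the limit (via truncate_batch), which is why it clears the input inside stream_limit rather than in poll_next.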
diff --git a/rust/datafusion/src/test/exec.rs b/rust/datafusion/src/test/exec.rs
new file mode 100644
index 00000000000..04cd29530c0
--- /dev/null
+++ b/rust/datafusion/src/test/exec.rs
@@ -0,0 +1,102 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Simple iterator over batches for use in testing
+
+use std::task::{Context, Poll};
+
+use arrow::{
+    datatypes::SchemaRef, error::Result as ArrowResult, record_batch::RecordBatch,
+};
+use futures::Stream;
+
+use crate::physical_plan::RecordBatchStream;
+
+/// Index into the data that has been returned so far
+#[derive(Debug, Default, Clone)]
+pub struct BatchIndex {
+    inner: std::sync::Arc<std::sync::Mutex<usize>>,
+}
+
+impl BatchIndex {
+    /// Return the current index
+    pub fn value(&self) -> usize {
+        let inner = self.inner.lock().unwrap();
+        *inner
+    }
+
+    /// Increment the current index by one
+    pub fn incr(&self) {
+        let mut inner = self.inner.lock().unwrap();
+        *inner += 1;
+    }
+}
+
+/// Iterator over batches
+#[derive(Debug, Default)]
+pub(crate) struct TestStream {
+    /// Vector of record batches
+    data: Vec<RecordBatch>,
+    /// Index into the data that has been returned so far
+    index: BatchIndex,
+}
+
+impl TestStream {
+    /// Create an iterator for a vector of record batches. Assumes at
+    /// least one entry in data (for the schema)
+    pub fn new(data: Vec<RecordBatch>) -> Self {
+        Self {
+            data,
+            ..Default::default()
+        }
+    }
+
+    /// Return a handle to the index counter for this stream
+    pub fn index(&self) -> BatchIndex {
+        self.index.clone()
+    }
+}
+
+impl Stream for TestStream {
+    type Item = ArrowResult<RecordBatch>;
+
+    fn poll_next(
+        self: std::pin::Pin<&mut Self>,
+        _: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let next_batch = self.index.value();
+
+        Poll::Ready(if next_batch < self.data.len() {
+            let next_batch = self.index.value();
+            self.index.incr();
+            Some(Ok(self.data[next_batch].clone()))
+        } else {
+            None
+        })
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        (self.data.len(), Some(self.data.len()))
+    }
+}
+
+impl RecordBatchStream for TestStream {
+    /// Get the schema
+    fn schema(&self) -> SchemaRef {
+        self.data[0].schema()
+    }
+}
diff --git a/rust/datafusion/src/test/mod.rs b/rust/datafusion/src/test/mod.rs
index 4fdf3346c9a..57736189481 100644
--- a/rust/datafusion/src/test/mod.rs
+++ b/rust/datafusion/src/test/mod.rs
@@ -182,6 +182,7 @@ pub fn make_partition(sz: i32) -> RecordBatch {
     RecordBatch::try_new(schema, vec![arr]).unwrap()
 }
 
+pub mod exec;
 pub mod user_defined;
 pub mod variable;
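BatchIndex is what makes limit_early_shutdown observable from the outside: it is a cloneable handle over Arc<Mutex<usize>>, so the test keeps one clone while the TestStream owns another, and both read the same count of batches pulled. The sharing mechanics in isolation (a sketch in plain Rust, no Arrow types needed):

    use std::sync::{Arc, Mutex};

    /// Cloneable counter handle: one clone lives inside the stream,
    /// another in the test, both reading the same underlying count.
    #[derive(Debug, Default, Clone)]
    struct BatchIndex {
        inner: Arc<Mutex<usize>>,
    }

    impl BatchIndex {
        fn value(&self) -> usize {
            *self.inner.lock().unwrap()
        }
        fn incr(&self) {
            *self.inner.lock().unwrap() += 1;
        }
    }

    fn main() {
        let index = BatchIndex::default();
        let handle = index.clone(); // what TestStream::index() hands back

        // the "stream" side records two batches being consumed...
        index.incr();
        index.incr();

        // ...and the test side observes that through its own handle
        assert_eq!(handle.value(), 2);
    }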