diff --git a/rust/datafusion/src/execution/context.rs b/rust/datafusion/src/execution/context.rs
index 4c419d983a6..a8199665ecf 100644
--- a/rust/datafusion/src/execution/context.rs
+++ b/rust/datafusion/src/execution/context.rs
@@ -1752,7 +1752,7 @@ mod tests {
     async fn limit() -> Result<()> {
         let tmp_dir = TempDir::new()?;
         let mut ctx = create_ctx(&tmp_dir, 1)?;
-        ctx.register_table("t", table_with_sequence(1, 1000).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1000).unwrap())
             .unwrap();
 
         let results =
@@ -1788,30 +1788,18 @@ mod tests {
         Ok(())
     }
 
-    /// Return a RecordBatch with a single Int32 array with values (0..sz)
-    fn make_partition(sz: i32) -> RecordBatch {
-        let seq_start = 0;
-        let seq_end = sz;
-        let values = (seq_start..seq_end).collect::<Vec<_>>();
-        let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
-        let arr = Arc::new(Int32Array::from(values));
-        let arr = arr as ArrayRef;
-
-        RecordBatch::try_new(schema, vec![arr]).unwrap()
-    }
-
     #[tokio::test]
     async fn limit_multi_partitions() -> Result<()> {
         let tmp_dir = TempDir::new()?;
         let mut ctx = create_ctx(&tmp_dir, 1)?;
 
         let partitions = vec![
-            vec![make_partition(0)],
-            vec![make_partition(1)],
-            vec![make_partition(2)],
-            vec![make_partition(3)],
-            vec![make_partition(4)],
-            vec![make_partition(5)],
+            vec![test::make_partition(0)],
+            vec![test::make_partition(1)],
+            vec![test::make_partition(2)],
+            vec![test::make_partition(3)],
+            vec![test::make_partition(4)],
+            vec![test::make_partition(5)],
         ];
         let schema = partitions[0][0].schema();
         let provider = Arc::new(MemTable::try_new(schema, partitions).unwrap());
@@ -1838,7 +1826,7 @@ mod tests {
     #[tokio::test]
     async fn case_sensitive_identifiers_functions() {
         let mut ctx = ExecutionContext::new();
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         let expected = vec![
@@ -1878,7 +1866,7 @@ mod tests {
     #[tokio::test]
     async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
         let mut ctx = ExecutionContext::new();
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
@@ -1918,7 +1906,7 @@ mod tests {
     #[tokio::test]
     async fn case_sensitive_identifiers_aggregates() {
         let mut ctx = ExecutionContext::new();
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         let expected = vec![
@@ -1958,7 +1946,7 @@ mod tests {
     #[tokio::test]
     async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
         let mut ctx = ExecutionContext::new();
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         // Note capitalizaton
@@ -2356,19 +2344,6 @@ mod tests {
         Ok(())
     }
 
-    fn table_with_sequence(
-        seq_start: i32,
-        seq_end: i32,
-    ) -> Result<Arc<dyn TableProvider>> {
-        let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
-        let arr = Arc::new(Int32Array::from((seq_start..=seq_end).collect::<Vec<_>>()));
-        let partitions = vec![vec![RecordBatch::try_new(
-            schema.clone(),
-            vec![arr as ArrayRef],
-        )?]];
-        Ok(Arc::new(MemTable::try_new(schema, partitions)?))
-    }
-
     #[tokio::test]
     async fn information_schema_tables_not_exist_by_default() {
         let mut ctx = ExecutionContext::new();
@@ -2411,7 +2386,7 @@ mod tests {
         );
 
         // Now, register an empty table
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         let result =
@@ -2431,7 +2406,7 @@ mod tests {
         assert_batches_sorted_eq!(expected, &result);
 
         // Newly added tables should appear
-        ctx.register_table("t2", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t2", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         let result =
@@ -2460,10 +2435,10 @@ mod tests {
         let catalog = MemoryCatalogProvider::new();
         let schema = MemorySchemaProvider::new();
         schema
-            .register_table("t1".to_owned(), table_with_sequence(1, 1).unwrap())
+            .register_table("t1".to_owned(), test::table_with_sequence(1, 1).unwrap())
             .unwrap();
         schema
-            .register_table("t2".to_owned(), table_with_sequence(1, 1).unwrap())
+            .register_table("t2".to_owned(), test::table_with_sequence(1, 1).unwrap())
             .unwrap();
         catalog.register_schema("my_schema", Arc::new(schema));
         ctx.register_catalog("my_catalog", Arc::new(catalog));
@@ -2471,7 +2446,7 @@ mod tests {
         let catalog = MemoryCatalogProvider::new();
         let schema = MemorySchemaProvider::new();
         schema
-            .register_table("t3".to_owned(), table_with_sequence(1, 1).unwrap())
+            .register_table("t3".to_owned(), test::table_with_sequence(1, 1).unwrap())
             .unwrap();
         catalog.register_schema("my_other_schema", Arc::new(schema));
         ctx.register_catalog("my_other_catalog", Arc::new(catalog));
@@ -2503,7 +2478,7 @@ mod tests {
     async fn information_schema_show_tables_no_information_schema() {
         let mut ctx = ExecutionContext::with_config(ExecutionConfig::new());
 
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         // use show tables alias
@@ -2518,7 +2493,7 @@ mod tests {
             ExecutionConfig::new().with_information_schema(true),
         );
 
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         // use show tables alias
@@ -2544,7 +2519,7 @@ mod tests {
     async fn information_schema_show_columns_no_information_schema() {
         let mut ctx = ExecutionContext::with_config(ExecutionConfig::new());
 
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t")
@@ -2558,7 +2533,7 @@ mod tests {
     async fn information_schema_show_columns_like_where() {
         let mut ctx = ExecutionContext::with_config(ExecutionConfig::new());
 
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         let expected =
@@ -2582,7 +2557,7 @@ mod tests {
             ExecutionConfig::new().with_information_schema(true),
         );
 
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         let result = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t")
@@ -2620,7 +2595,7 @@ mod tests {
             ExecutionConfig::new().with_information_schema(true),
        );
 
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         let result = plan_and_collect(&mut ctx, "SHOW FULL COLUMNS FROM t")
@@ -2649,7 +2624,7 @@ mod tests {
             ExecutionConfig::new().with_information_schema(true),
         );
 
-        ctx.register_table("t", table_with_sequence(1, 1).unwrap())
+        ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         let result = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM public.t")
@@ -2752,7 +2727,7 @@ mod tests {
         let schema = MemorySchemaProvider::new();
 
         schema
-            .register_table("t1".to_owned(), table_with_sequence(1, 1).unwrap())
+            .register_table("t1".to_owned(), test::table_with_sequence(1, 1).unwrap())
             .unwrap();
 
         schema
@@ -2790,7 +2765,7 @@ mod tests {
         );
 
         assert!(matches!(
-            ctx.register_table("test", table_with_sequence(1, 1)?),
+            ctx.register_table("test", test::table_with_sequence(1, 1)?),
             Err(DataFusionError::Plan(_))
         ));
 
@@ -2812,7 +2787,7 @@ mod tests {
 
         let catalog = MemoryCatalogProvider::new();
         let schema = MemorySchemaProvider::new();
-        schema.register_table("test".to_owned(), table_with_sequence(1, 1)?)?;
+        schema.register_table("test".to_owned(), test::table_with_sequence(1, 1)?)?;
         catalog.register_schema("my_schema", Arc::new(schema));
         ctx.register_catalog("my_catalog", Arc::new(catalog));
 
@@ -2842,13 +2817,15 @@ mod tests {
 
         let catalog_a = MemoryCatalogProvider::new();
         let schema_a = MemorySchemaProvider::new();
-        schema_a.register_table("table_a".to_owned(), table_with_sequence(1, 1)?)?;
+        schema_a
+            .register_table("table_a".to_owned(), test::table_with_sequence(1, 1)?)?;
         catalog_a.register_schema("schema_a", Arc::new(schema_a));
         ctx.register_catalog("catalog_a", Arc::new(catalog_a));
 
         let catalog_b = MemoryCatalogProvider::new();
         let schema_b = MemorySchemaProvider::new();
-        schema_b.register_table("table_b".to_owned(), table_with_sequence(1, 2)?)?;
+        schema_b
+            .register_table("table_b".to_owned(), test::table_with_sequence(1, 2)?)?;
         catalog_b.register_schema("schema_b", Arc::new(schema_b));
         ctx.register_catalog("catalog_b", Arc::new(catalog_b));
 
diff --git a/rust/datafusion/src/physical_plan/limit.rs b/rust/datafusion/src/physical_plan/limit.rs
index 2f0eb682dd8..c091196483f 100644
--- a/rust/datafusion/src/physical_plan/limit.rs
+++ b/rust/datafusion/src/physical_plan/limit.rs
@@ -200,23 +200,31 @@ pub fn truncate_batch(batch: &RecordBatch, n: usize) -> RecordBatch {
 
 /// A Limit stream limits the stream to up to `limit` rows.
 struct LimitStream {
+    /// The maximum number of rows to produce
     limit: usize,
-    input: SendableRecordBatchStream,
-    // the current count
+    /// The input to read from. This is set to None once the limit is
+    /// reached to enable early termination
+    input: Option<SendableRecordBatchStream>,
+    /// Copy of the input schema
+    schema: SchemaRef,
+    // the current number of rows which have been produced
     current_len: usize,
 }
 
 impl LimitStream {
     fn new(input: SendableRecordBatchStream, limit: usize) -> Self {
+        let schema = input.schema();
         Self {
             limit,
-            input,
+            input: Some(input),
+            schema,
             current_len: 0,
         }
     }
 
     fn stream_limit(&mut self, batch: RecordBatch) -> Option<RecordBatch> {
         if self.current_len == self.limit {
+            self.input = None; // clear input so it can be dropped early
             None
         } else if self.current_len + batch.num_rows() <= self.limit {
             self.current_len += batch.num_rows();
@@ -224,6 +232,7 @@ impl LimitStream {
         } else {
             let batch_rows = self.limit - self.current_len;
             self.current_len = self.limit;
+            self.input = None; // clear input so it can be dropped early
             Some(truncate_batch(&batch, batch_rows))
         }
     }
@@ -236,23 +245,29 @@ impl Stream for LimitStream {
         mut self: Pin<&mut Self>,
         cx: &mut Context<'_>,
     ) -> Poll<Option<Self::Item>> {
-        self.input.poll_next_unpin(cx).map(|x| match x {
-            Some(Ok(batch)) => Ok(self.stream_limit(batch)).transpose(),
-            other => other,
-        })
+        match &mut self.input {
+            Some(input) => input.poll_next_unpin(cx).map(|x| match x {
+                Some(Ok(batch)) => Ok(self.stream_limit(batch)).transpose(),
+                other => other,
+            }),
+            // input has been cleared
+            None => Poll::Ready(None),
+        }
     }
 }
 
 impl RecordBatchStream for LimitStream {
     /// Get the schema
     fn schema(&self) -> SchemaRef {
-        self.input.schema()
+        self.schema.clone()
     }
 }
 
 #[cfg(test)]
 mod tests {
 
+    use common::collect;
+
     use super::*;
     use crate::physical_plan::common;
     use crate::physical_plan::csv::{CsvExec, CsvReadOptions};
@@ -290,4 +305,34 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn limit_early_shutdown() -> Result<()> {
+        let batches = vec![
+            test::make_partition(5),
+            test::make_partition(10),
+            test::make_partition(15),
+            test::make_partition(20),
+            test::make_partition(25),
+        ];
+        let input = test::exec::TestStream::new(batches);
+
+        let index = input.index();
+        assert_eq!(index.value(), 0);
+
+        // a limit of six needs to consume the entire first record
+        // batch (5 rows) and one row from the second
+        let limit_stream = LimitStream::new(Box::pin(input), 6);
+        assert_eq!(index.value(), 0);
+
+        let results = collect(Box::pin(limit_stream)).await.unwrap();
+        let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
+        // Only 6 rows should have been produced
+        assert_eq!(num_rows, 6);
+
+        // Only the first two batches should be consumed
+        assert_eq!(index.value(), 2);
+
+        Ok(())
+    }
 }
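The heart of this diff is storing the input stream in an `Option` so it can be dropped the moment the limit is satisfied: `stream_limit` sets `input` to `None`, `poll_next` then short-circuits with `Poll::Ready(None)`, and the cached `schema` keeps `schema()` working after the input is gone. Dropping the input is what lets upstream operators stop producing. Since `LimitStream` itself is private to the module, here is the same Option-wrapping pattern distilled into a standalone, hypothetical combinator (all names invented, not DataFusion API):

```rust
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::{Stream, StreamExt};

/// Yields at most `remaining` items, dropping the wrapped stream as
/// soon as the budget is exhausted so its resources are freed early.
struct TakeEarlyDrop<S> {
    inner: Option<S>,
    remaining: usize,
}

impl<S: Stream + Unpin> Stream for TakeEarlyDrop<S> {
    type Item = S::Item;

    fn poll_next(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<S::Item>> {
        // Budget already exhausted: make sure the input is gone and
        // report end-of-stream (setting None twice is harmless).
        if self.remaining == 0 {
            self.inner = None;
            return Poll::Ready(None);
        }
        let polled = match &mut self.inner {
            None => return Poll::Ready(None),
            Some(inner) => inner.poll_next_unpin(cx),
        };
        match polled {
            Poll::Ready(Some(item)) => {
                self.remaining -= 1;
                if self.remaining == 0 {
                    // Drop the input now, rather than waiting for the
                    // whole wrapper to be dropped by the consumer.
                    self.inner = None;
                }
                Poll::Ready(Some(item))
            }
            // The input ended on its own, or is not ready yet.
            other => other,
        }
    }
}
```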
diff --git a/rust/datafusion/src/test/exec.rs b/rust/datafusion/src/test/exec.rs
new file mode 100644
index 00000000000..04cd29530c0
--- /dev/null
+++ b/rust/datafusion/src/test/exec.rs
@@ -0,0 +1,102 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Simple iterator over batches for use in testing
+
+use std::task::{Context, Poll};
+
+use arrow::{
+    datatypes::SchemaRef, error::Result as ArrowResult, record_batch::RecordBatch,
+};
+use futures::Stream;
+
+use crate::physical_plan::RecordBatchStream;
+
+/// Index into the data that has been returned so far
+#[derive(Debug, Default, Clone)]
+pub struct BatchIndex {
+    inner: std::sync::Arc<std::sync::Mutex<usize>>,
+}
+
+impl BatchIndex {
+    /// Return the current index
+    pub fn value(&self) -> usize {
+        let inner = self.inner.lock().unwrap();
+        *inner
+    }
+
+    // increment the current index by one
+    pub fn incr(&self) {
+        let mut inner = self.inner.lock().unwrap();
+        *inner += 1;
+    }
+}
+
+/// Iterator over batches
+#[derive(Debug, Default)]
+pub(crate) struct TestStream {
+    /// Vector of record batches
+    data: Vec<RecordBatch>,
+    /// Index into the data that has been returned so far
+    index: BatchIndex,
+}
+
+impl TestStream {
+    /// Create an iterator for a vector of record batches. Assumes at
+    /// least one entry in data (for the schema)
+    pub fn new(data: Vec<RecordBatch>) -> Self {
+        Self {
+            data,
+            ..Default::default()
+        }
+    }
+
+    /// Return a handle to the index counter for this stream
+    pub fn index(&self) -> BatchIndex {
+        self.index.clone()
+    }
+}
+
+impl Stream for TestStream {
+    type Item = ArrowResult<RecordBatch>;
+
+    fn poll_next(
+        self: std::pin::Pin<&mut Self>,
+        _: &mut Context<'_>,
+    ) -> Poll<Option<Self::Item>> {
+        let next_batch = self.index.value();
+
+        Poll::Ready(if next_batch < self.data.len() {
+            let next_batch = self.index.value();
+            self.index.incr();
+            Some(Ok(self.data[next_batch].clone()))
+        } else {
+            None
+        })
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        (self.data.len(), Some(self.data.len()))
+    }
+}
+
+impl RecordBatchStream for TestStream {
+    /// Get the schema
+    fn schema(&self) -> SchemaRef {
+        self.data[0].schema()
+    }
+}
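`BatchIndex` is a cloneable handle over an `Arc<Mutex<usize>>`: a test keeps one handle while the `TestStream` (holding the other) is consumed by the operator under test, which is exactly how `limit_early_shutdown` above observes that only two of the five batches were pulled. A small hypothetical usage sketch (crate-internal, since `TestStream` is `pub(crate)`; the test name is invented):

```rust
use futures::StreamExt;

use crate::test::{self, exec::TestStream};

#[tokio::test]
async fn teststream_index_tracks_consumption() {
    let batches = vec![test::make_partition(1), test::make_partition(2)];
    let mut stream = TestStream::new(batches);

    // The handle and the stream share the same counter
    let index = stream.index();
    assert_eq!(index.value(), 0);

    let _first = stream.next().await; // pull one batch through the Stream API
    assert_eq!(index.value(), 1);
}
```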
diff --git a/rust/datafusion/src/test/mod.rs b/rust/datafusion/src/test/mod.rs
index 04f340a9936..57736189481 100644
--- a/rust/datafusion/src/test/mod.rs
+++ b/rust/datafusion/src/test/mod.rs
@@ -20,6 +20,7 @@
 use crate::datasource::{MemTable, TableProvider};
 use crate::error::Result;
 use crate::logical_plan::{LogicalPlan, LogicalPlanBuilder};
+use array::ArrayRef;
 use arrow::array::{self, Int32Array};
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
@@ -154,6 +155,34 @@
 pub fn columns(schema: &Schema) -> Vec<String> {
     schema.fields().iter().map(|f| f.name().clone()).collect()
 }
 
+/// Return a new table provider that has a single Int32 column with
+/// values between `seq_start` and `seq_end`
+pub fn table_with_sequence(
+    seq_start: i32,
+    seq_end: i32,
+) -> Result<Arc<dyn TableProvider>> {
+    let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
+    let arr = Arc::new(Int32Array::from((seq_start..=seq_end).collect::<Vec<_>>()));
+    let partitions = vec![vec![RecordBatch::try_new(
+        schema.clone(),
+        vec![arr as ArrayRef],
+    )?]];
+    Ok(Arc::new(MemTable::try_new(schema, partitions)?))
+}
+
+/// Return a RecordBatch with a single Int32 array with values (0..sz)
+pub fn make_partition(sz: i32) -> RecordBatch {
+    let seq_start = 0;
+    let seq_end = sz;
+    let values = (seq_start..seq_end).collect::<Vec<_>>();
+    let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
+    let arr = Arc::new(Int32Array::from(values));
+    let arr = arr as ArrayRef;
+
+    RecordBatch::try_new(schema, vec![arr]).unwrap()
+}
+
+pub mod exec;
 pub mod user_defined;
 pub mod variable;
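One detail to keep in mind when reusing the relocated helpers: they use different range conventions. `make_partition(sz)` fills its batch from the half-open range `0..sz`, while `table_with_sequence(seq_start, seq_end)` uses the inclusive range `seq_start..=seq_end`. A quick hypothetical check (the test name is invented):

```rust
use crate::test;

#[test]
fn helper_range_conventions() {
    // half-open: values 0, 1, 2, 3, 4 -> five rows
    assert_eq!(test::make_partition(5).num_rows(), 5);

    // inclusive: table_with_sequence(1, 5) holds 1, 2, 3, 4, 5 (also
    // five rows); both helpers expose one Int32 column named "i"
    let provider = test::table_with_sequence(1, 5).unwrap();
    assert_eq!(provider.schema().fields().len(), 1);
}
```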