83 changes: 30 additions & 53 deletions rust/datafusion/src/execution/context.rs
@@ -1752,7 +1752,7 @@ mod tests {
async fn limit() -> Result<()> {
let tmp_dir = TempDir::new()?;
let mut ctx = create_ctx(&tmp_dir, 1)?;
ctx.register_table("t", table_with_sequence(1, 1000).unwrap())
ctx.register_table("t", test::table_with_sequence(1, 1000).unwrap())
.unwrap();

let results =
@@ -1788,30 +1788,18 @@ mod tests {
Ok(())
}

- /// Return a RecordBatch with a single Int32 array with values (0..sz)
- fn make_partition(sz: i32) -> RecordBatch {
- let seq_start = 0;
- let seq_end = sz;
- let values = (seq_start..seq_end).collect::<Vec<_>>();
- let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
- let arr = Arc::new(Int32Array::from(values));
- let arr = arr as ArrayRef;
-
- RecordBatch::try_new(schema, vec![arr]).unwrap()
- }

#[tokio::test]
async fn limit_multi_partitions() -> Result<()> {
let tmp_dir = TempDir::new()?;
let mut ctx = create_ctx(&tmp_dir, 1)?;

let partitions = vec![
- vec![make_partition(0)],
- vec![make_partition(1)],
- vec![make_partition(2)],
- vec![make_partition(3)],
- vec![make_partition(4)],
- vec![make_partition(5)],
+ vec![test::make_partition(0)],
+ vec![test::make_partition(1)],
+ vec![test::make_partition(2)],
+ vec![test::make_partition(3)],
+ vec![test::make_partition(4)],
+ vec![test::make_partition(5)],
];
let schema = partitions[0][0].schema();
let provider = Arc::new(MemTable::try_new(schema, partitions).unwrap());
@@ -1838,7 +1826,7 @@ mod tests {
#[tokio::test]
async fn case_sensitive_identifiers_functions() {
let mut ctx = ExecutionContext::new();
ctx.register_table("t", table_with_sequence(1, 1).unwrap())
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();

let expected = vec![
@@ -1878,7 +1866,7 @@ mod tests {
#[tokio::test]
async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
let mut ctx = ExecutionContext::new();
ctx.register_table("t", table_with_sequence(1, 1).unwrap())
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();

let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
@@ -1918,7 +1906,7 @@ mod tests {
#[tokio::test]
async fn case_sensitive_identifiers_aggregates() {
let mut ctx = ExecutionContext::new();
ctx.register_table("t", table_with_sequence(1, 1).unwrap())
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();

let expected = vec![
@@ -1958,7 +1946,7 @@ mod tests {
#[tokio::test]
async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> {
let mut ctx = ExecutionContext::new();
ctx.register_table("t", table_with_sequence(1, 1).unwrap())
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();

// Note capitalization
@@ -2356,19 +2344,6 @@ mod tests {
Ok(())
}

- fn table_with_sequence(
- seq_start: i32,
- seq_end: i32,
- ) -> Result<Arc<dyn TableProvider>> {
- let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
- let arr = Arc::new(Int32Array::from((seq_start..=seq_end).collect::<Vec<_>>()));
- let partitions = vec![vec![RecordBatch::try_new(
- schema.clone(),
- vec![arr as ArrayRef],
- )?]];
- Ok(Arc::new(MemTable::try_new(schema, partitions)?))
- }

#[tokio::test]
async fn information_schema_tables_not_exist_by_default() {
let mut ctx = ExecutionContext::new();
@@ -2411,7 +2386,7 @@
);

// Now, register an empty table
ctx.register_table("t", table_with_sequence(1, 1).unwrap())
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();

let result =
@@ -2431,7 +2406,7 @@
assert_batches_sorted_eq!(expected, &result);

// Newly added tables should appear
ctx.register_table("t2", table_with_sequence(1, 1).unwrap())
ctx.register_table("t2", test::table_with_sequence(1, 1).unwrap())
.unwrap();

let result =
@@ -2460,18 +2435,18 @@
let catalog = MemoryCatalogProvider::new();
let schema = MemorySchemaProvider::new();
schema
.register_table("t1".to_owned(), table_with_sequence(1, 1).unwrap())
.register_table("t1".to_owned(), test::table_with_sequence(1, 1).unwrap())
.unwrap();
schema
.register_table("t2".to_owned(), table_with_sequence(1, 1).unwrap())
.register_table("t2".to_owned(), test::table_with_sequence(1, 1).unwrap())
.unwrap();
catalog.register_schema("my_schema", Arc::new(schema));
ctx.register_catalog("my_catalog", Arc::new(catalog));

let catalog = MemoryCatalogProvider::new();
let schema = MemorySchemaProvider::new();
schema
.register_table("t3".to_owned(), table_with_sequence(1, 1).unwrap())
.register_table("t3".to_owned(), test::table_with_sequence(1, 1).unwrap())
.unwrap();
catalog.register_schema("my_other_schema", Arc::new(schema));
ctx.register_catalog("my_other_catalog", Arc::new(catalog));
@@ -2503,7 +2478,7 @@
async fn information_schema_show_tables_no_information_schema() {
let mut ctx = ExecutionContext::with_config(ExecutionConfig::new());

ctx.register_table("t", table_with_sequence(1, 1).unwrap())
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();

// use show tables alias
@@ -2518,7 +2493,7 @@
ExecutionConfig::new().with_information_schema(true),
);

ctx.register_table("t", table_with_sequence(1, 1).unwrap())
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();

// use show tables alias
@@ -2544,7 +2519,7 @@
async fn information_schema_show_columns_no_information_schema() {
let mut ctx = ExecutionContext::with_config(ExecutionConfig::new());

ctx.register_table("t", table_with_sequence(1, 1).unwrap())
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();

let err = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t")
@@ -2558,7 +2533,7 @@
async fn information_schema_show_columns_like_where() {
let mut ctx = ExecutionContext::with_config(ExecutionConfig::new());

ctx.register_table("t", table_with_sequence(1, 1).unwrap())
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();

let expected =
@@ -2582,7 +2557,7 @@
ExecutionConfig::new().with_information_schema(true),
);

ctx.register_table("t", table_with_sequence(1, 1).unwrap())
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();

let result = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM t")
@@ -2620,7 +2595,7 @@
ExecutionConfig::new().with_information_schema(true),
);

ctx.register_table("t", table_with_sequence(1, 1).unwrap())
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();

let result = plan_and_collect(&mut ctx, "SHOW FULL COLUMNS FROM t")
@@ -2649,7 +2624,7 @@
ExecutionConfig::new().with_information_schema(true),
);

ctx.register_table("t", table_with_sequence(1, 1).unwrap())
ctx.register_table("t", test::table_with_sequence(1, 1).unwrap())
.unwrap();

let result = plan_and_collect(&mut ctx, "SHOW COLUMNS FROM public.t")
@@ -2752,7 +2727,7 @@
let schema = MemorySchemaProvider::new();

schema
.register_table("t1".to_owned(), table_with_sequence(1, 1).unwrap())
.register_table("t1".to_owned(), test::table_with_sequence(1, 1).unwrap())
.unwrap();

schema
@@ -2790,7 +2765,7 @@
);

assert!(matches!(
ctx.register_table("test", table_with_sequence(1, 1)?),
ctx.register_table("test", test::table_with_sequence(1, 1)?),
Err(DataFusionError::Plan(_))
));


let catalog = MemoryCatalogProvider::new();
let schema = MemorySchemaProvider::new();
schema.register_table("test".to_owned(), table_with_sequence(1, 1)?)?;
schema.register_table("test".to_owned(), test::table_with_sequence(1, 1)?)?;
catalog.register_schema("my_schema", Arc::new(schema));
ctx.register_catalog("my_catalog", Arc::new(catalog));

@@ -2842,13 +2817,15 @@

let catalog_a = MemoryCatalogProvider::new();
let schema_a = MemorySchemaProvider::new();
schema_a.register_table("table_a".to_owned(), table_with_sequence(1, 1)?)?;
schema_a
.register_table("table_a".to_owned(), test::table_with_sequence(1, 1)?)?;
catalog_a.register_schema("schema_a", Arc::new(schema_a));
ctx.register_catalog("catalog_a", Arc::new(catalog_a));

let catalog_b = MemoryCatalogProvider::new();
let schema_b = MemorySchemaProvider::new();
schema_b.register_table("table_b".to_owned(), table_with_sequence(1, 2)?)?;
schema_b
.register_table("table_b".to_owned(), test::table_with_sequence(1, 2)?)?;
catalog_b.register_schema("schema_b", Arc::new(schema_b));
ctx.register_catalog("catalog_b", Arc::new(catalog_b));

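For context, the tests in this file now call shared helpers from the crate's test module rather than the local copies removed above. A minimal sketch of what the consolidated helpers presumably look like, assuming they keep the removed bodies essentially verbatim and live somewhere like rust/datafusion/src/test/mod.rs (the exact module path and pub visibility are assumptions, not shown in this diff):

/// Return a RecordBatch with a single Int32 array with values (0..sz)
pub fn make_partition(sz: i32) -> RecordBatch {
    let seq_start = 0;
    let seq_end = sz;
    let values = (seq_start..seq_end).collect::<Vec<_>>();
    let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
    let arr = Arc::new(Int32Array::from(values));
    let arr = arr as ArrayRef;

    RecordBatch::try_new(schema, vec![arr]).unwrap()
}

/// Return a single-partition in-memory table with column "i" holding seq_start..=seq_end
pub fn table_with_sequence(
    seq_start: i32,
    seq_end: i32,
) -> Result<Arc<dyn TableProvider>> {
    let schema = Arc::new(Schema::new(vec![Field::new("i", DataType::Int32, true)]));
    let arr = Arc::new(Int32Array::from((seq_start..=seq_end).collect::<Vec<_>>()));
    let partitions = vec![vec![RecordBatch::try_new(
        schema.clone(),
        vec![arr as ArrayRef],
    )?]];
    Ok(Arc::new(MemTable::try_new(schema, partitions)?))
}

With these in one place, both context.rs and limit.rs tests can register tables via test::table_with_sequence(...) and build batches via test::make_partition(...), as the diff hunks show.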
61 changes: 53 additions & 8 deletions rust/datafusion/src/physical_plan/limit.rs
@@ -200,30 +200,39 @@ pub fn truncate_batch(batch: &RecordBatch, n: usize) -> RecordBatch {

/// A Limit stream limits the stream to up to `limit` rows.
struct LimitStream {
+ /// The maximum number of rows to produce
limit: usize,
- input: SendableRecordBatchStream,
- // the current count
+ /// The input to read from. This is set to None once the limit is
+ /// reached to enable early termination
+ input: Option<SendableRecordBatchStream>,
+ /// Copy of the input schema
+ schema: SchemaRef,
+ // the current number of rows which have been produced
current_len: usize,
}

[Review comment on the new input field: Dandandan (Contributor), Apr 7, 2021]
Is this necessary? It can be based on self.current_len == self.limit or otherwise a boolean like limit_exhausted?

[Review comment: Contributor]
haha 👍 😆

impl LimitStream {
fn new(input: SendableRecordBatchStream, limit: usize) -> Self {
+ let schema = input.schema();

[Review comment: Contributor Author]
The actual code change for this PR is very small -- the rest of the changes are related to writing a proper test for it.

[Review comment: Contributor]
I have a vague memory that FusedStream may have something to do with this property (although /noideadog).

[Review comment: Contributor Author]
This is a good point. One benefit of this PR over fuse() is that this PR will actually drop the input stream (freeing resources) in addition to not calling the input stream again: https://docs.rs/futures-util/0.3.13/src/futures_util/stream/stream/fuse.rs.html#10

Self {
limit,
- input,
+ input: Some(input),
+ schema,
current_len: 0,
}
}

fn stream_limit(&mut self, batch: RecordBatch) -> Option<RecordBatch> {
if self.current_len == self.limit {
+ self.input = None; // clear input so it can be dropped early
None
} else if self.current_len + batch.num_rows() <= self.limit {
self.current_len += batch.num_rows();
Some(batch)
} else {
let batch_rows = self.limit - self.current_len;
self.current_len = self.limit;
+ self.input = None; // clear input so it can be dropped early
Some(truncate_batch(&batch, batch_rows))
}
}
@@ -236,23 +245,29 @@ impl Stream for LimitStream {
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Self::Item>> {
- self.input.poll_next_unpin(cx).map(|x| match x {
- Some(Ok(batch)) => Ok(self.stream_limit(batch)).transpose(),
- other => other,
- })
+ match &mut self.input {
+ Some(input) => input.poll_next_unpin(cx).map(|x| match x {
+ Some(Ok(batch)) => Ok(self.stream_limit(batch)).transpose(),
+ other => other,
+ }),
+ // input has been cleared
+ None => Poll::Ready(None),
+ }

[Review comment: Contributor]
Why not just compare self.current_len == self.limit and short-circuit before polling the wrapped stream, instead of the Option plumbing?

[Review comment: Contributor Author]
My thinking was that the option plumbing actually drops the input, freeing its resources when the limit has been hit, rather than waiting for the execution to be complete.

[Review comment: Contributor]
Fair enough

[Review comment: Contributor]
Do we do this in other places too? Isn't a SendableRecordBatchStream a small struct?

[Review comment: Contributor Author]
> Isn't a SendableRecordBatchStream a small struct?

It is a trait, so there are various things that implement it. Some, like the ParquetStream (Ok(Box::pin(ParquetStream {), could have substantial resources.

[Review comment: Contributor]
As far as I understand it now, it can consist of the whole tree of dependent streams. Probably still not a big resource hog but more than a few bytes.

[Review comment: Contributor Author]
Imagine it is actually a subquery with a group by hash or join with a large hash table :) It may actually be hanging on to a substantial amount of memory I suspect.

[Review comment: Contributor]
Ah, yeah that's right!
}
}
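To make the review discussion above concrete, here is a minimal standalone sketch (not DataFusion code) of the difference between fuse() and the Option plumbing chosen here. Per the futures-util source linked above, Fuse keeps the inner stream alive until the wrapper itself is dropped, while an Option field lets the wrapper free the inner stream as soon as it is no longer needed:

// futures_util::stream::Fuse is roughly (per the linked source):
//
//     pub struct Fuse<St> {
//         stream: St, // retained until the Fuse itself is dropped
//         done: bool,
//     }
//
// A sketch of the Option-based alternative used by LimitStream:
struct EarlyDropping<St> {
    inner: Option<St>,
}

impl<St> EarlyDropping<St> {
    /// Called once no further items are needed (e.g. the limit is hit).
    /// Overwriting with None runs the inner stream's destructor right away,
    /// releasing whatever it owns (buffers, hash tables, file handles).
    fn finish(&mut self) {
        self.inner = None;
    }
}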

impl RecordBatchStream for LimitStream {
/// Get the schema
fn schema(&self) -> SchemaRef {
- self.input.schema()
+ self.schema.clone()
}
}

#[cfg(test)]
mod tests {

+ use common::collect;

use super::*;
use crate::physical_plan::common;
use crate::physical_plan::csv::{CsvExec, CsvReadOptions};
@@ -290,4 +305,34 @@ mod tests {

Ok(())
}

+ #[tokio::test]
+ async fn limit_early_shutdown() -> Result<()> {
+ let batches = vec![
+ test::make_partition(5),
+ test::make_partition(10),
+ test::make_partition(15),
+ test::make_partition(20),
+ test::make_partition(25),
+ ];
+ let input = test::exec::TestStream::new(batches);
+
+ let index = input.index();
+ assert_eq!(index.value(), 0);
+
+ // a limit of six needs to consume the entire first record batch
+ // (5 rows) and one row from the second
+ let limit_stream = LimitStream::new(Box::pin(input), 6);
+ assert_eq!(index.value(), 0);
+
+ let results = collect(Box::pin(limit_stream)).await.unwrap();
+ let num_rows: usize = results.into_iter().map(|b| b.num_rows()).sum();
+ // Only 6 rows should have been produced
+ assert_eq!(num_rows, 6);
+
+ // Only the first two batches should be consumed
+ assert_eq!(index.value(), 2);
+
+ Ok(())
+ }
}
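The diff for the TestStream helper itself is not shown on this page, so here is a rough, hypothetical sketch of what an index-tracking test stream matching the usage above could look like (names inferred from the test; the real test::exec module may differ):

use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};

use arrow::datatypes::SchemaRef;
use arrow::error::Result as ArrowResult;
use arrow::record_batch::RecordBatch;
use futures::Stream;

/// Shared counter of how many batches have been handed out
#[derive(Clone)]
pub struct BatchIndex {
    inner: Arc<Mutex<usize>>,
}

impl BatchIndex {
    pub fn value(&self) -> usize {
        *self.inner.lock().unwrap()
    }
}

/// Yields a fixed set of pre-made batches, bumping the shared index for
/// each batch produced so a test can observe how far it was polled
pub struct TestStream {
    batches: Vec<RecordBatch>,
    index: BatchIndex,
}

impl TestStream {
    pub fn new(batches: Vec<RecordBatch>) -> Self {
        let inner = Arc::new(Mutex::new(0));
        Self { batches, index: BatchIndex { inner } }
    }

    /// Handle the test keeps after the stream itself is boxed away
    pub fn index(&self) -> BatchIndex {
        self.index.clone()
    }
}

impl Stream for TestStream {
    type Item = ArrowResult<RecordBatch>;

    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut i = self.index.inner.lock().unwrap();
        let next = self.batches.get(*i).cloned();
        if next.is_some() {
            *i += 1; // count the batch we are about to hand out
        }
        Poll::Ready(next.map(Ok))
    }
}

// DataFusion's RecordBatchStream trait would also be implemented (returning
// batches[0].schema(), assuming at least one batch) so the stream can be
// boxed as a SendableRecordBatchStream and passed to LimitStream::new.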