From 60545c84e1e3fa032d433095cf822db38c983487 Mon Sep 17 00:00:00 2001 From: Wang Date: Fri, 11 Mar 2022 15:25:55 +0800 Subject: [PATCH] Refactor ExecutionContext and related conf to support multi-tenancy configurations - Part 1 --- ballista/rust/client/src/context.rs | 6 +- .../src/execution_plans/distributed_query.rs | 4 +- .../src/execution_plans/shuffle_reader.rs | 4 +- .../src/execution_plans/shuffle_writer.rs | 21 +- .../src/execution_plans/unresolved_shuffle.rs | 4 +- .../rust/core/src/serde/logical_plan/mod.rs | 8 +- ballista/rust/core/src/serde/mod.rs | 23 +- .../src/serde/physical_plan/from_proto.rs | 24 +- .../rust/core/src/serde/physical_plan/mod.rs | 10 +- ballista/rust/core/src/utils.rs | 10 +- ballista/rust/executor/src/collect.rs | 6 +- ballista/rust/executor/src/execution_loop.rs | 17 ++ ballista/rust/executor/src/executor.rs | 16 +- ballista/rust/executor/src/executor_server.rs | 17 ++ ballista/rust/executor/src/main.rs | 4 +- ballista/rust/executor/src/standalone.rs | 4 +- ballista/rust/scheduler/src/main.rs | 6 +- ballista/rust/scheduler/src/planner.rs | 4 +- .../scheduler/src/scheduler_server/grpc.rs | 4 +- .../scheduler/src/scheduler_server/mod.rs | 16 +- .../scheduler_server/query_stage_scheduler.rs | 6 +- ballista/rust/scheduler/src/standalone.rs | 4 +- ballista/rust/scheduler/src/state/mod.rs | 4 +- .../scheduler/src/state/persistent_state.rs | 6 +- ballista/rust/scheduler/src/test_utils.rs | 9 +- benchmarks/src/bin/nyctaxi.rs | 14 +- benchmarks/src/bin/tpch.rs | 36 ++- datafusion-cli/src/context.rs | 8 +- datafusion-cli/src/main.rs | 8 +- datafusion-examples/examples/avro_sql.rs | 2 +- datafusion-examples/examples/csv_sql.rs | 2 +- .../examples/custom_datasource.rs | 6 +- datafusion-examples/examples/dataframe.rs | 2 +- .../examples/dataframe_in_memory.rs | 2 +- datafusion-examples/examples/flight_server.rs | 2 +- datafusion-examples/examples/memtable.rs | 4 +- datafusion-examples/examples/parquet_sql.rs | 4 +- .../examples/parquet_sql_multiple_files.rs | 2 +- datafusion-examples/examples/simple_udaf.rs | 6 +- datafusion-examples/examples/simple_udf.rs | 4 +- datafusion/Cargo.toml | 1 + datafusion/benches/aggregate_query_sql.rs | 8 +- datafusion/benches/filter_query_sql.rs | 8 +- datafusion/benches/math_query_sql.rs | 8 +- datafusion/benches/parquet_query_sql.rs | 4 +- datafusion/benches/physical_plan.rs | 21 +- datafusion/benches/sort_limit_query_sql.rs | 16 +- datafusion/benches/window_query_sql.rs | 8 +- datafusion/src/catalog/schema.rs | 4 +- datafusion/src/dataframe.rs | 127 ++++----- datafusion/src/datasource/file_format/avro.rs | 47 ++-- datafusion/src/datasource/file_format/csv.rs | 18 +- datafusion/src/datasource/file_format/json.rs | 18 +- .../src/datasource/file_format/parquet.rs | 48 ++-- datafusion/src/datasource/listing/helpers.rs | 4 +- datafusion/src/datasource/memory.rs | 26 +- datafusion/src/execution/context.rs | 259 ++++++++++++------ datafusion/src/lib.rs | 6 +- .../aggregate_statistics.rs | 21 +- .../physical_optimizer/coalesce_batches.rs | 2 +- .../hash_build_probe_order.rs | 10 +- .../src/physical_optimizer/merge_exec.rs | 2 +- .../src/physical_optimizer/optimizer.rs | 4 +- .../src/physical_optimizer/repartition.rs | 6 +- datafusion/src/physical_optimizer/utils.rs | 6 +- datafusion/src/physical_plan/analyze.rs | 12 +- .../src/physical_plan/coalesce_batches.rs | 12 +- .../src/physical_plan/coalesce_partitions.rs | 19 +- datafusion/src/physical_plan/common.rs | 6 +- datafusion/src/physical_plan/cross_join.rs | 8 +- 
datafusion/src/physical_plan/empty.rs | 22 +- datafusion/src/physical_plan/explain.rs | 4 +- .../src/physical_plan/file_format/avro.rs | 8 +- .../src/physical_plan/file_format/csv.rs | 40 +-- .../src/physical_plan/file_format/json.rs | 22 +- .../src/physical_plan/file_format/parquet.rs | 42 +-- datafusion/src/physical_plan/filter.rs | 12 +- .../src/physical_plan/hash_aggregate.rs | 27 +- datafusion/src/physical_plan/hash_join.rs | 111 ++++---- datafusion/src/physical_plan/limit.rs | 16 +- datafusion/src/physical_plan/memory.rs | 15 +- datafusion/src/physical_plan/mod.rs | 28 +- datafusion/src/physical_plan/planner.rs | 145 +++++----- datafusion/src/physical_plan/projection.rs | 12 +- datafusion/src/physical_plan/repartition.rs | 64 +++-- datafusion/src/physical_plan/sorts/sort.rs | 43 +-- .../sorts/sort_preserving_merge.rs | 113 ++++---- datafusion/src/physical_plan/union.rs | 12 +- datafusion/src/physical_plan/values.rs | 4 +- datafusion/src/physical_plan/windows/mod.rs | 12 +- .../physical_plan/windows/window_agg_exec.rs | 6 +- datafusion/src/prelude.rs | 2 +- datafusion/src/test/exec.rs | 15 +- datafusion/tests/custom_sources.rs | 17 +- datafusion/tests/dataframe.rs | 8 +- datafusion/tests/dataframe_functions.rs | 4 +- datafusion/tests/merge_fuzz.rs | 22 +- datafusion/tests/order_spill_fuzz.rs | 10 +- datafusion/tests/parquet_pruning.rs | 16 +- datafusion/tests/path_partition.rs | 18 +- datafusion/tests/provider_filter_pushdown.rs | 7 +- datafusion/tests/sql/aggregates.rs | 118 ++++---- datafusion/tests/sql/avro.rs | 14 +- datafusion/tests/sql/create_drop.rs | 12 +- datafusion/tests/sql/errors.rs | 12 +- datafusion/tests/sql/explain.rs | 4 +- datafusion/tests/sql/explain_analyze.rs | 48 ++-- datafusion/tests/sql/expr.rs | 70 ++--- datafusion/tests/sql/functions.rs | 18 +- datafusion/tests/sql/group_by.rs | 62 ++--- datafusion/tests/sql/information_schema.rs | 59 ++-- datafusion/tests/sql/intersection.rs | 16 +- datafusion/tests/sql/joins.rs | 136 ++++----- datafusion/tests/sql/limit.rs | 16 +- datafusion/tests/sql/mod.rs | 57 ++-- datafusion/tests/sql/order.rs | 26 +- datafusion/tests/sql/parquet.rs | 18 +- datafusion/tests/sql/partitioned_csv.rs | 8 +- datafusion/tests/sql/predicates.rs | 72 ++--- datafusion/tests/sql/projection.rs | 23 +- datafusion/tests/sql/references.rs | 18 +- datafusion/tests/sql/select.rs | 142 +++++----- datafusion/tests/sql/timestamp.rs | 120 ++++---- datafusion/tests/sql/udf.rs | 8 +- datafusion/tests/sql/unicode.rs | 2 +- datafusion/tests/sql/union.rs | 14 +- datafusion/tests/sql/window.rs | 16 +- datafusion/tests/statistics.rs | 10 +- datafusion/tests/user_defined_plan.rs | 42 ++- docs/source/python/api/execution_context.rst | 4 +- docs/source/python/index.rst | 2 +- .../user-guide/distributed/clients/rust.md | 2 +- docs/source/user-guide/example-usage.md | 4 +- docs/source/user-guide/library.md | 2 +- 134 files changed, 1678 insertions(+), 1452 deletions(-) diff --git a/ballista/rust/client/src/context.rs b/ballista/rust/client/src/context.rs index 8175a69837737..4a5fe6d30bfe7 100644 --- a/ballista/rust/client/src/context.rs +++ b/ballista/rust/client/src/context.rs @@ -34,7 +34,7 @@ use datafusion::datasource::TableProvider; use datafusion::error::{DataFusionError, Result}; use datafusion::logical_plan::{CreateExternalTable, LogicalPlan, TableScan}; use datafusion::prelude::{ - AvroReadOptions, CsvReadOptions, ExecutionConfig, ExecutionContext, + AvroReadOptions, CsvReadOptions, SessionConfig, SessionContext, }; use datafusion::sql::parser::{DFParser, 
FileType, Statement as DFStatement}; @@ -304,8 +304,8 @@ impl BallistaContext { // the show tables、 show columns sql can not run at scheduler because the tables is store at client if is_show { let state = self.state.lock(); - ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema( + ctx = SessionContext::with_config( + SessionConfig::new().with_information_schema( state.config.default_with_information_schema(), ), ); diff --git a/ballista/rust/core/src/execution_plans/distributed_query.rs b/ballista/rust/core/src/execution_plans/distributed_query.rs index a226622fe9db8..d4daeb26c42e8 100644 --- a/ballista/rust/core/src/execution_plans/distributed_query.rs +++ b/ballista/rust/core/src/execution_plans/distributed_query.rs @@ -42,7 +42,7 @@ use datafusion::physical_plan::{ use crate::serde::{AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec}; use async_trait::async_trait; -use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::context::TaskContext; use futures::future; use futures::StreamExt; use log::{error, info}; @@ -150,7 +150,7 @@ impl ExecutionPlan for DistributedQueryExec { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { assert_eq!(0, partition); diff --git a/ballista/rust/core/src/execution_plans/shuffle_reader.rs b/ballista/rust/core/src/execution_plans/shuffle_reader.rs index 3bebcd12e1558..aeabc72a5f917 100644 --- a/ballista/rust/core/src/execution_plans/shuffle_reader.rs +++ b/ballista/rust/core/src/execution_plans/shuffle_reader.rs @@ -25,7 +25,6 @@ use crate::utils::WrappedStream; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; -use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::metrics::{ ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, @@ -39,6 +38,7 @@ use datafusion::{ }; use futures::{future, StreamExt}; +use datafusion::execution::context::TaskContext; use log::info; /// ShuffleReaderExec reads partitions that have already been materialized by a ShuffleWriterExec @@ -106,7 +106,7 @@ impl ExecutionPlan for ShuffleReaderExec { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { info!("ShuffleReaderExec::execute({})", partition); diff --git a/ballista/rust/core/src/execution_plans/shuffle_writer.rs b/ballista/rust/core/src/execution_plans/shuffle_writer.rs index b80fc8492083a..5f4d67f976aaf 100644 --- a/ballista/rust/core/src/execution_plans/shuffle_writer.rs +++ b/ballista/rust/core/src/execution_plans/shuffle_writer.rs @@ -42,7 +42,6 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::physical_plan::common::IPCWriter; use datafusion::physical_plan::hash_utils::create_hashes; use datafusion::physical_plan::memory::MemoryStream; @@ -55,6 +54,7 @@ use datafusion::physical_plan::{ }; use futures::StreamExt; +use datafusion::execution::context::TaskContext; use log::{debug, info}; /// ShuffleWriterExec represents a section of a query plan that has consistent partitioning and @@ -138,11 +138,11 @@ impl ShuffleWriterExec { pub async fn execute_shuffle_write( &self, input_partition: usize, - runtime: Arc, + context: Arc, ) -> Result> { let now = Instant::now(); - let mut stream = self.plan.execute(input_partition, 
runtime).await?; + let mut stream = self.plan.execute(input_partition, context).await?; let mut path = PathBuf::from(&self.work_dir); path.push(&self.job_id); @@ -358,9 +358,9 @@ impl ExecutionPlan for ShuffleWriterExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { - let part_loc = self.execute_shuffle_write(partition, runtime).await?; + let part_loc = self.execute_shuffle_write(partition, context).await?; // build metadata result batch let num_writers = part_loc.len(); @@ -448,11 +448,13 @@ mod tests { use datafusion::physical_plan::expressions::Column; use datafusion::physical_plan::memory::MemoryExec; + use datafusion::prelude::SessionContext; use tempfile::TempDir; #[tokio::test] async fn test() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let input_plan = Arc::new(CoalescePartitionsExec::new(create_input_plan()?)); let work_dir = TempDir::new()?; @@ -463,7 +465,7 @@ mod tests { work_dir.into_path().to_str().unwrap().to_owned(), Some(Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 2)), )?; - let mut stream = query_stage.execute(0, runtime).await?; + let mut stream = query_stage.execute(0, task_ctx).await?; let batches = utils::collect_stream(&mut stream) .await .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; @@ -506,7 +508,8 @@ mod tests { #[tokio::test] async fn test_partitioned() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let input_plan = create_input_plan()?; let work_dir = TempDir::new()?; @@ -517,7 +520,7 @@ mod tests { work_dir.into_path().to_str().unwrap().to_owned(), Some(Partitioning::Hash(vec![Arc::new(Column::new("a", 0))], 2)), )?; - let mut stream = query_stage.execute(0, runtime).await?; + let mut stream = query_stage.execute(0, task_ctx).await?; let batches = utils::collect_stream(&mut stream) .await .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?; diff --git a/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs b/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs index 418546aa389cd..868620be0f358 100644 --- a/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs +++ b/ballista/rust/core/src/execution_plans/unresolved_shuffle.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use async_trait::async_trait; use datafusion::arrow::datatypes::SchemaRef; use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::context::TaskContext; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, @@ -104,7 +104,7 @@ impl ExecutionPlan for UnresolvedShuffleExec { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { Err(DataFusionError::Plan( "Ballista UnresolvedShuffleExec does not support execution".to_owned(), diff --git a/ballista/rust/core/src/serde/logical_plan/mod.rs b/ballista/rust/core/src/serde/logical_plan/mod.rs index 4970cd600a5a7..737979355afe8 100644 --- a/ballista/rust/core/src/serde/logical_plan/mod.rs +++ b/ballista/rust/core/src/serde/logical_plan/mod.rs @@ -36,7 +36,7 @@ use datafusion::logical_plan::{ Column, CreateExternalTable, CrossJoin, Expr, JoinConstraint, Limit, LogicalPlan, LogicalPlanBuilder, Repartition, TableScan, 
Values, }; -use datafusion::prelude::ExecutionContext; +use datafusion::prelude::SessionContext; use prost::bytes::BufMut; use prost::Message; @@ -70,7 +70,7 @@ impl AsLogicalPlan for LogicalPlanNode { fn try_into_logical_plan( &self, - ctx: &ExecutionContext, + ctx: &SessionContext, extension_codec: &dyn LogicalExtensionCodec, ) -> Result { let plan = self.logical_plan_type.as_ref().ok_or_else(|| { @@ -920,7 +920,7 @@ mod roundtrip_tests { roundtrip_test!($initial_struct, protobuf::LogicalPlanNode, $struct_type); }; ($initial_struct:ident) => { - let ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let codec: BallistaCodec< protobuf::LogicalPlanNode, protobuf::PhysicalPlanNode, @@ -1252,7 +1252,7 @@ mod roundtrip_tests { #[tokio::test] async fn roundtrip_logical_plan_custom_ctx() -> Result<()> { - let ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let codec: BallistaCodec = BallistaCodec::default(); let custom_object_store = Arc::new(TestObjectStore {}); diff --git a/ballista/rust/core/src/serde/mod.rs b/ballista/rust/core/src/serde/mod.rs index b85a957d43f6e..cc1bbb4174295 100644 --- a/ballista/rust/core/src/serde/mod.rs +++ b/ballista/rust/core/src/serde/mod.rs @@ -30,7 +30,7 @@ use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction}; use datafusion::logical_plan::plan::Extension; use datafusion::physical_plan::ExecutionPlan; -use datafusion::prelude::ExecutionContext; +use datafusion::prelude::SessionContext; use prost::Message; // include the generated protobuf source as a submodule @@ -67,7 +67,7 @@ pub trait AsLogicalPlan: Debug + Send + Sync + Clone { fn try_into_logical_plan( &self, - ctx: &ExecutionContext, + ctx: &SessionContext, extension_codec: &dyn LogicalExtensionCodec, ) -> Result; @@ -130,7 +130,7 @@ pub trait AsExecutionPlan: Debug + Send + Sync + Clone { fn try_into_physical_plan( &self, - ctx: &ExecutionContext, + ctx: &SessionContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result, BallistaError>; @@ -345,8 +345,7 @@ mod tests { use datafusion::arrow::datatypes::SchemaRef; use datafusion::datasource::object_store::local::LocalFileSystem; use datafusion::error::DataFusionError; - use datafusion::execution::context::{ExecutionContextState, QueryPlanner}; - use datafusion::execution::runtime_env::RuntimeEnv; + use datafusion::execution::context::{QueryPlanner, SessionState, TaskContext}; use datafusion::logical_plan::plan::Extension; use datafusion::logical_plan::{ col, DFSchemaRef, Expr, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, @@ -357,7 +356,7 @@ mod tests { DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalPlanner, SendableRecordBatchStream, Statistics, }; - use datafusion::prelude::{CsvReadOptions, ExecutionConfig, ExecutionContext}; + use datafusion::prelude::{CsvReadOptions, SessionConfig, SessionContext}; use prost::Message; use std::any::Any; @@ -512,7 +511,7 @@ mod tests { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> datafusion::error::Result { Err(DataFusionError::NotImplemented( "not implemented".to_string(), @@ -548,7 +547,7 @@ mod tests { node: &dyn UserDefinedLogicalNode, logical_inputs: &[&LogicalPlan], physical_inputs: &[Arc], - _ctx_state: &ExecutionContextState, + _session_state: &SessionState, ) -> datafusion::error::Result>> { Ok( if let Some(topk_node) = node.as_any().downcast_ref::() { @@ -575,7 +574,7 @@ mod tests { async fn create_physical_plan( &self, logical_plan: &LogicalPlan, - ctx_state: 
&ExecutionContextState, + session_state: &SessionState, ) -> datafusion::error::Result> { // Teach the default physical planner how to plan TopK nodes. let physical_planner = @@ -584,7 +583,7 @@ mod tests { )]); // Delegate most work of physical planning to the default physical planner physical_planner - .create_physical_plan(logical_plan, ctx_state) + .create_physical_plan(logical_plan, session_state) .await } } @@ -694,9 +693,9 @@ mod tests { async fn test_extension_plan() -> crate::error::Result<()> { let store = Arc::new(LocalFileSystem {}); let config = - ExecutionConfig::new().with_query_planner(Arc::new(TopKQueryPlanner {})); + SessionConfig::new().with_query_planner(Arc::new(TopKQueryPlanner {})); - let ctx = ExecutionContext::with_config(config); + let ctx = SessionContext::with_config(config); let scan = LogicalPlanBuilder::scan_csv( store, diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 8daefc904b7c1..79161a315a47c 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -26,14 +26,10 @@ use crate::serde::{from_proto_binary_op, proto_error, protobuf}; use crate::{convert_box_required, convert_required}; use chrono::{TimeZone, Utc}; -use datafusion::catalog::catalog::{CatalogList, MemoryCatalogList}; use datafusion::datasource::object_store::local::LocalFileSystem; -use datafusion::datasource::object_store::{FileMeta, ObjectStoreRegistry, SizedFile}; +use datafusion::datasource::object_store::{FileMeta, SizedFile}; use datafusion::datasource::PartitionedFile; -use datafusion::execution::context::{ - ExecutionConfig, ExecutionContextState, ExecutionProps, -}; -use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::context::SessionState; use datafusion::physical_plan::file_format::FileScanConfig; @@ -157,22 +153,12 @@ impl TryFrom<&protobuf::PhysicalExprNode> for Arc { .map(|x| x.try_into()) .collect::, _>>()?; - let catalog_list = - Arc::new(MemoryCatalogList::new()) as Arc; - - let ctx_state = ExecutionContextState { - catalog_list, - scalar_functions: Default::default(), - aggregate_functions: Default::default(), - config: ExecutionConfig::new(), - execution_props: ExecutionProps::new(), - object_store_registry: Arc::new(ObjectStoreRegistry::new()), - runtime_env: Arc::new(RuntimeEnv::default()), - }; + // TODO Do not create new the SessionState + let session_state = SessionState::new(); let fun_expr = functions::create_physical_fun( &(&scalar_function).into(), - &ctx_state.execution_props, + &session_state.execution_props, )?; Arc::new(ScalarFunctionExpr::new( diff --git a/ballista/rust/core/src/serde/physical_plan/mod.rs b/ballista/rust/core/src/serde/physical_plan/mod.rs index 83607ae6e5551..9c151229dbc3d 100644 --- a/ballista/rust/core/src/serde/physical_plan/mod.rs +++ b/ballista/rust/core/src/serde/physical_plan/mod.rs @@ -56,7 +56,7 @@ use datafusion::physical_plan::windows::{create_window_expr, WindowAggExec}; use datafusion::physical_plan::{ AggregateExpr, ExecutionPlan, Partitioning, PhysicalExpr, WindowExpr, }; -use datafusion::prelude::ExecutionContext; +use datafusion::prelude::SessionContext; use prost::bytes::BufMut; use prost::Message; use std::convert::TryInto; @@ -87,7 +87,7 @@ impl AsExecutionPlan for PhysicalPlanNode { fn try_into_physical_plan( &self, - ctx: &ExecutionContext, + ctx: &SessionContext, extension_codec: &dyn PhysicalExtensionCodec, ) -> Result, 
BallistaError> { let plan = self.physical_plan_type.as_ref().ok_or_else(|| { @@ -883,7 +883,7 @@ impl AsExecutionPlan for PhysicalPlanNode { fn decode_scan_config( proto: &protobuf::FileScanExecConf, - ctx: &ExecutionContext, + ctx: &SessionContext, ) -> Result { let schema = Arc::new(convert_required!(proto.schema)?); let projection = proto @@ -940,7 +940,7 @@ mod roundtrip_tests { use datafusion::datasource::object_store::local::LocalFileSystem; use datafusion::datasource::PartitionedFile; use datafusion::physical_plan::sorts::sort::SortExec; - use datafusion::prelude::ExecutionContext; + use datafusion::prelude::SessionContext; use datafusion::{ arrow::{ compute::kernels::sort::SortOptions, @@ -969,7 +969,7 @@ mod roundtrip_tests { use crate::serde::protobuf::{LogicalPlanNode, PhysicalPlanNode}; fn roundtrip_test(exec_plan: Arc) -> Result<()> { - let ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let codec: BallistaCodec = BallistaCodec::default(); let proto: protobuf::PhysicalPlanNode = diff --git a/ballista/rust/core/src/utils.rs b/ballista/rust/core/src/utils.rs index 560d459977ddd..a668b7334cfd5 100644 --- a/ballista/rust/core/src/utils.rs +++ b/ballista/rust/core/src/utils.rs @@ -38,7 +38,7 @@ use datafusion::arrow::{ }; use datafusion::error::DataFusionError; use datafusion::execution::context::{ - ExecutionConfig, ExecutionContext, ExecutionContextState, QueryPlanner, + QueryPlanner, SessionConfig, SessionContext, SessionState, }; use datafusion::logical_plan::LogicalPlan; @@ -230,15 +230,15 @@ pub fn create_df_ctx_with_ballista_query_planner( scheduler_host: &str, scheduler_port: u16, config: &BallistaConfig, -) -> ExecutionContext { +) -> SessionContext { let scheduler_url = format!("http://{}:{}", scheduler_host, scheduler_port); let planner: Arc> = Arc::new(BallistaQueryPlanner::new(scheduler_url, config.clone())); - let config = ExecutionConfig::new() + let config = SessionConfig::new() .with_query_planner(planner) .with_target_partitions(config.default_shuffle_partitions()) .with_information_schema(true); - ExecutionContext::with_config(config) + SessionContext::with_config(config) } pub struct BallistaQueryPlanner { @@ -291,7 +291,7 @@ impl QueryPlanner for BallistaQueryPlanner { async fn create_physical_plan( &self, logical_plan: &LogicalPlan, - _ctx_state: &ExecutionContextState, + _session_state: &SessionState, ) -> std::result::Result, DataFusionError> { match logical_plan { LogicalPlan::CreateExternalTable(_) => { diff --git a/ballista/rust/executor/src/collect.rs b/ballista/rust/executor/src/collect.rs index 37a7f7bb0d1bd..50215f6c99a88 100644 --- a/ballista/rust/executor/src/collect.rs +++ b/ballista/rust/executor/src/collect.rs @@ -27,7 +27,7 @@ use datafusion::arrow::{ datatypes::SchemaRef, error::Result as ArrowResult, record_batch::RecordBatch, }; use datafusion::error::DataFusionError; -use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::context::TaskContext; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, @@ -81,12 +81,12 @@ impl ExecutionPlan for CollectExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { assert_eq!(0, partition); let num_partitions = self.plan.output_partitioning().partition_count(); - let futures = (0..num_partitions).map(|i| self.plan.execute(i, runtime.clone())); + let futures = (0..num_partitions).map(|i| 
self.plan.execute(i, context.clone())); let streams = futures::future::join_all(futures) .await .into_iter() diff --git a/ballista/rust/executor/src/execution_loop.rs b/ballista/rust/executor/src/execution_loop.rs index ddb2c972a70ff..93e18841720a5 100644 --- a/ballista/rust/executor/src/execution_loop.rs +++ b/ballista/rust/executor/src/execution_loop.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::mpsc::{Receiver, Sender, TryRecvError}; use std::{sync::Arc, time::Duration}; @@ -34,6 +35,7 @@ use ballista_core::error::BallistaError; use ballista_core::serde::physical_plan::from_proto::parse_protobuf_hash_partitioning; use ballista_core::serde::scheduler::ExecutorSpecification; use ballista_core::serde::{AsExecutionPlan, AsLogicalPlan, BallistaCodec}; +use datafusion::execution::context::TaskContext; pub async fn poll_loop( mut scheduler: SchedulerGrpcClient, @@ -124,6 +126,20 @@ async fn run_received_tasks = U::try_decode(task.plan.as_slice()).and_then(|proto| { proto.try_into_physical_plan( @@ -142,6 +158,7 @@ async fn run_received_tasks, + /// DataFusion session context + pub ctx: Arc, } impl Executor { @@ -46,7 +46,7 @@ impl Executor { pub fn new( metadata: ExecutorRegistration, work_dir: &str, - ctx: Arc, + ctx: Arc, ) -> Self { Self { metadata, @@ -66,6 +66,7 @@ impl Executor { stage_id: usize, part: usize, plan: Arc, + task_ctx: Arc, _shuffle_output_partitioning: Option, ) -> Result, BallistaError> { let exec = if let Some(shuffle_writer) = @@ -86,10 +87,7 @@ impl Executor { )) }?; - let config = ExecutionConfig::new().with_temp_file_path(self.work_dir.clone()); - let runtime = Arc::new(RuntimeEnv::new(config.runtime)?); - - let partitions = exec.execute_shuffle_write(part, runtime).await?; + let partitions = exec.execute_shuffle_write(part, task_ctx).await?; println!( "=== [{}/{}/{}] Physical plan with metrics ===\n{}\n", diff --git a/ballista/rust/executor/src/executor_server.rs b/ballista/rust/executor/src/executor_server.rs index ad34634265bba..74273999e3296 100644 --- a/ballista/rust/executor/src/executor_server.rs +++ b/ballista/rust/executor/src/executor_server.rs @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+use std::collections::HashMap; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; use tokio::sync::mpsc; @@ -36,6 +37,7 @@ use ballista_core::serde::protobuf::{ }; use ballista_core::serde::scheduler::ExecutorState; use ballista_core::serde::{AsExecutionPlan, AsLogicalPlan, BallistaCodec}; +use datafusion::execution::context::TaskContext; use datafusion::physical_plan::ExecutionPlan; use crate::as_task_status; @@ -177,6 +179,20 @@ impl ExecutorServer = @@ -197,6 +213,7 @@ impl ExecutorServer Result<()> { let executor = Arc::new(Executor::new( executor_meta, &work_dir, - Arc::new(ExecutionContext::new()), + Arc::new(SessionContext::new()), )); let scheduler = SchedulerGrpcClient::connect(scheduler_url) diff --git a/ballista/rust/executor/src/standalone.rs b/ballista/rust/executor/src/standalone.rs index 0bc2503e9dfce..dc55e0180fcf1 100644 --- a/ballista/rust/executor/src/standalone.rs +++ b/ballista/rust/executor/src/standalone.rs @@ -27,7 +27,7 @@ use ballista_core::{ serde::protobuf::{scheduler_grpc_client::SchedulerGrpcClient, ExecutorRegistration}, BALLISTA_VERSION, }; -use datafusion::prelude::ExecutionContext; +use datafusion::prelude::SessionContext; use log::info; use tempfile::TempDir; use tokio::net::TcpListener; @@ -71,7 +71,7 @@ pub async fn new_standalone_executor< .into_string() .unwrap(); info!("work_dir: {}", work_dir); - let ctx = Arc::new(ExecutionContext::new()); + let ctx = Arc::new(SessionContext::new()); let executor = Arc::new(Executor::new(executor_meta, &work_dir, ctx)); let service = BallistaFlightService::new(executor.clone()); diff --git a/ballista/rust/scheduler/src/main.rs b/ballista/rust/scheduler/src/main.rs index 4b74573cc3ae9..f35650cae1201 100644 --- a/ballista/rust/scheduler/src/main.rs +++ b/ballista/rust/scheduler/src/main.rs @@ -62,7 +62,7 @@ mod config { } use config::prelude::*; -use datafusion::prelude::ExecutionContext; +use datafusion::prelude::SessionContext; async fn start_server( config_backend: Arc, @@ -85,13 +85,13 @@ async fn start_server( config_backend.clone(), namespace.clone(), policy, - Arc::new(RwLock::new(ExecutionContext::new())), + Arc::new(RwLock::new(SessionContext::new())), BallistaCodec::default(), ), _ => SchedulerServer::new( config_backend.clone(), namespace.clone(), - Arc::new(RwLock::new(ExecutionContext::new())), + Arc::new(RwLock::new(SessionContext::new())), BallistaCodec::default(), ), }; diff --git a/ballista/rust/scheduler/src/planner.rs b/ballista/rust/scheduler/src/planner.rs index 68f26e9ffa1d6..b18b213499ea5 100644 --- a/ballista/rust/scheduler/src/planner.rs +++ b/ballista/rust/scheduler/src/planner.rs @@ -259,7 +259,7 @@ mod test { coalesce_partitions::CoalescePartitionsExec, projection::ProjectionExec, }; use datafusion::physical_plan::{displayable, ExecutionPlan}; - use datafusion::prelude::ExecutionContext; + use datafusion::prelude::SessionContext; use ballista_core::serde::protobuf::{LogicalPlanNode, PhysicalPlanNode}; use std::sync::Arc; @@ -574,7 +574,7 @@ order by fn roundtrip_operator( plan: Arc, ) -> Result, BallistaError> { - let ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let codec: BallistaCodec = BallistaCodec::default(); let proto: protobuf::PhysicalPlanNode = diff --git a/ballista/rust/scheduler/src/scheduler_server/grpc.rs b/ballista/rust/scheduler/src/scheduler_server/grpc.rs index 6b96d41e990c0..aaa98fb7207cf 100644 --- a/ballista/rust/scheduler/src/scheduler_server/grpc.rs +++ b/ballista/rust/scheduler/src/scheduler_server/grpc.rs @@ -490,7 +490,7 @@ 
mod test { }; use ballista_core::serde::scheduler::ExecutorSpecification; use ballista_core::serde::BallistaCodec; - use datafusion::prelude::ExecutionContext; + use datafusion::prelude::SessionContext; use super::{SchedulerGrpc, SchedulerServer}; @@ -502,7 +502,7 @@ mod test { SchedulerServer::new( state_storage.clone(), namespace.to_owned(), - Arc::new(RwLock::new(ExecutionContext::new())), + Arc::new(RwLock::new(SessionContext::new())), BallistaCodec::default(), ); let exec_meta = ExecutorRegistration { diff --git a/ballista/rust/scheduler/src/scheduler_server/mod.rs b/ballista/rust/scheduler/src/scheduler_server/mod.rs index 9106df7f7aee6..51f6fe4180b41 100644 --- a/ballista/rust/scheduler/src/scheduler_server/mod.rs +++ b/ballista/rust/scheduler/src/scheduler_server/mod.rs @@ -28,7 +28,7 @@ use ballista_core::event_loop::EventLoop; use ballista_core::serde::protobuf::executor_grpc_client::ExecutorGrpcClient; use ballista_core::serde::{AsExecutionPlan, AsLogicalPlan, BallistaCodec}; -use datafusion::prelude::{ExecutionConfig, ExecutionContext}; +use datafusion::prelude::{SessionConfig, SessionContext}; use crate::scheduler_server::event_loop::{ SchedulerServerEvent, SchedulerServerEventAction, @@ -60,7 +60,7 @@ pub struct SchedulerServer, event_loop: Option>, query_stage_event_loop: EventLoop, - ctx: Arc>, + ctx: Arc>, codec: BallistaCodec, } @@ -68,7 +68,7 @@ impl SchedulerServer, namespace: String, - ctx: Arc>, + ctx: Arc>, codec: BallistaCodec, ) -> Self { SchedulerServer::new_with_policy( @@ -84,7 +84,7 @@ impl SchedulerServer, namespace: String, policy: TaskSchedulingPolicy, - ctx: Arc>, + ctx: Arc>, codec: BallistaCodec, ) -> Self { let state = Arc::new(SchedulerState::new(config, namespace, codec.clone())); @@ -161,8 +161,8 @@ impl SchedulerServer ExecutionContext { - let config = ExecutionConfig::new() - .with_target_partitions(config.default_shuffle_partitions()); - ExecutionContext::with_config(config) +pub fn create_datafusion_context(config: &BallistaConfig) -> SessionContext { + let config = + SessionConfig::new().with_target_partitions(config.default_shuffle_partitions()); + SessionContext::with_config(config) } diff --git a/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs b/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs index 31a0f9d3127e8..52af5484c8be3 100644 --- a/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs +++ b/ballista/rust/scheduler/src/scheduler_server/query_stage_scheduler.rs @@ -30,7 +30,7 @@ use ballista_core::serde::protobuf::{ use ballista_core::serde::{AsExecutionPlan, AsLogicalPlan}; use datafusion::logical_plan::LogicalPlan; use datafusion::physical_plan::ExecutionPlan; -use datafusion::prelude::ExecutionContext; +use datafusion::prelude::SessionContext; use crate::planner::DistributedPlanner; use crate::scheduler_server::event_loop::SchedulerServerEvent; @@ -45,14 +45,14 @@ pub(crate) struct QueryStageScheduler< T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan, > { - ctx: Arc>, + ctx: Arc>, state: Arc>, event_sender: Option>, } impl QueryStageScheduler { pub(crate) fn new( - ctx: Arc>, + ctx: Arc>, state: Arc>, event_sender: Option>, ) -> Self { diff --git a/ballista/rust/scheduler/src/standalone.rs b/ballista/rust/scheduler/src/standalone.rs index 45984b6b54766..b5404e713525a 100644 --- a/ballista/rust/scheduler/src/standalone.rs +++ b/ballista/rust/scheduler/src/standalone.rs @@ -21,7 +21,7 @@ use ballista_core::{ error::Result, 
serde::protobuf::scheduler_grpc_server::SchedulerGrpcServer, BALLISTA_VERSION, }; -use datafusion::prelude::ExecutionContext; +use datafusion::prelude::SessionContext; use log::info; use std::{net::SocketAddr, sync::Arc}; use tokio::net::TcpListener; @@ -39,7 +39,7 @@ pub async fn new_standalone_scheduler() -> Result { SchedulerServer::new( Arc::new(client), "ballista".to_string(), - Arc::new(RwLock::new(ExecutionContext::new())), + Arc::new(RwLock::new(SessionContext::new())), BallistaCodec::default(), ); scheduler_server.init().await?; diff --git a/ballista/rust/scheduler/src/state/mod.rs b/ballista/rust/scheduler/src/state/mod.rs index cc7252f680005..b673f7a9ea8ac 100644 --- a/ballista/rust/scheduler/src/state/mod.rs +++ b/ballista/rust/scheduler/src/state/mod.rs @@ -35,7 +35,7 @@ use ballista_core::serde::scheduler::{ ExecutorData, ExecutorDataChange, ExecutorMetadata, PartitionId, PartitionStats, }; use ballista_core::serde::{protobuf, AsExecutionPlan, AsLogicalPlan, BallistaCodec}; -use datafusion::prelude::ExecutionContext; +use datafusion::prelude::SessionContext; use super::planner::remove_unresolved_shuffles; @@ -134,7 +134,7 @@ impl SchedulerState Result<()> { + pub async fn init(&self, ctx: &SessionContext) -> Result<()> { self.persistent_state.init(ctx).await?; Ok(()) diff --git a/ballista/rust/scheduler/src/state/persistent_state.rs b/ballista/rust/scheduler/src/state/persistent_state.rs index 5c3417464996c..ac4ef2f55930c 100644 --- a/ballista/rust/scheduler/src/state/persistent_state.rs +++ b/ballista/rust/scheduler/src/state/persistent_state.rs @@ -30,7 +30,7 @@ use crate::state::backend::StateBackendClient; use ballista_core::serde::scheduler::ExecutorMetadata; use ballista_core::serde::{protobuf, AsExecutionPlan, AsLogicalPlan, BallistaCodec}; use datafusion::physical_plan::ExecutionPlan; -use datafusion::prelude::ExecutionContext; +use datafusion::prelude::SessionContext; type StageKey = (String, u32); @@ -70,7 +70,7 @@ impl } /// Load the state stored in storage into memory - pub(crate) async fn init(&self, ctx: &ExecutionContext) -> Result<()> { + pub(crate) async fn init(&self, ctx: &SessionContext) -> Result<()> { self.init_executors_metadata_from_storage().await?; self.init_jobs_from_storage().await?; self.init_stages_from_storage(ctx).await?; @@ -111,7 +111,7 @@ impl Ok(()) } - async fn init_stages_from_storage(&self, ctx: &ExecutionContext) -> Result<()> { + async fn init_stages_from_storage(&self, ctx: &SessionContext) -> Result<()> { let entries = self .config_client .get_from_prefix(&get_stage_prefix(&self.namespace)) diff --git a/ballista/rust/scheduler/src/test_utils.rs b/ballista/rust/scheduler/src/test_utils.rs index b9d7ee42f48b1..9d6e83fec89e8 100644 --- a/ballista/rust/scheduler/src/test_utils.rs +++ b/ballista/rust/scheduler/src/test_utils.rs @@ -18,18 +18,17 @@ use ballista_core::error::Result; use datafusion::arrow::datatypes::{DataType, Field, Schema}; -use datafusion::execution::context::{ExecutionConfig, ExecutionContext}; +use datafusion::execution::context::{SessionConfig, SessionContext}; use datafusion::prelude::CsvReadOptions; pub const TPCH_TABLES: &[&str] = &[ "part", "supplier", "partsupp", "customer", "orders", "lineitem", "nation", "region", ]; -pub async fn datafusion_test_context(path: &str) -> Result { +pub async fn datafusion_test_context(path: &str) -> Result { let default_shuffle_partitions = 2; - let config = - ExecutionConfig::new().with_target_partitions(default_shuffle_partitions); - let mut ctx = 
ExecutionContext::with_config(config); + let config = SessionConfig::new().with_target_partitions(default_shuffle_partitions); + let mut ctx = SessionContext::with_config(config); for table in TPCH_TABLES { let schema = get_tpch_schema(table); let options = CsvReadOptions::new() diff --git a/benchmarks/src/bin/nyctaxi.rs b/benchmarks/src/bin/nyctaxi.rs index 49679f46d7eba..a0cdb748a31e3 100644 --- a/benchmarks/src/bin/nyctaxi.rs +++ b/benchmarks/src/bin/nyctaxi.rs @@ -26,7 +26,7 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema}; use datafusion::arrow::util::pretty; use datafusion::error::Result; -use datafusion::execution::context::{ExecutionConfig, ExecutionContext}; +use datafusion::execution::context::{SessionConfig, SessionContext}; use datafusion::physical_plan::collect; use datafusion::prelude::CsvReadOptions; @@ -69,10 +69,10 @@ async fn main() -> Result<()> { let opt = Opt::from_args(); println!("Running benchmarks with the following options: {:?}", opt); - let config = ExecutionConfig::new() + let config = SessionConfig::new() .with_target_partitions(opt.partitions) .with_batch_size(opt.batch_size); - let mut ctx = ExecutionContext::with_config(config); + let mut ctx = SessionContext::with_config(config); let path = opt.path.to_str().unwrap(); @@ -93,7 +93,7 @@ async fn main() -> Result<()> { } async fn datafusion_sql_benchmarks( - ctx: &mut ExecutionContext, + ctx: &mut SessionContext, iterations: usize, debug: bool, ) -> Result<()> { @@ -115,15 +115,15 @@ async fn datafusion_sql_benchmarks( Ok(()) } -async fn execute_sql(ctx: &mut ExecutionContext, sql: &str, debug: bool) -> Result<()> { - let runtime = ctx.state.lock().runtime_env.clone(); +async fn execute_sql(ctx: &SessionContext, sql: &str, debug: bool) -> Result<()> { let plan = ctx.create_logical_plan(sql)?; let plan = ctx.optimize(&plan)?; if debug { println!("Optimized logical plan:\n{:?}", plan); } let physical_plan = ctx.create_physical_plan(&plan).await?; - let result = collect(physical_plan, runtime).await?; + let task_ctx = ctx.task_ctx(); + let result = collect(physical_plan, task_ctx).await?; if debug { pretty::print_batches(&result)?; } diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index 1cc6687891104..84fd432b35eea 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -273,11 +273,10 @@ async fn main() -> Result<()> { async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result> { println!("Running benchmarks with the following options: {:?}", opt); let mut benchmark_run = BenchmarkRun::new(opt.query); - let config = ExecutionConfig::new() + let config = SessionConfig::new() .with_target_partitions(opt.partitions) .with_batch_size(opt.batch_size); - let mut ctx = ExecutionContext::with_config(config); - let runtime = ctx.state.lock().runtime_env.clone(); + let mut ctx = SessionContext::with_config(config); // register tables for table in TABLES { @@ -290,10 +289,9 @@ async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result Result Result { } } -fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result { +fn create_logical_plan(ctx: &mut SessionContext, query: usize) -> Result { let sql = get_query_sql(query)?; ctx.create_logical_plan(&sql) } async fn execute_query( - ctx: &mut ExecutionContext, + ctx: &SessionContext, plan: &LogicalPlan, debug: bool, ) -> Result> { @@ -608,8 +606,8 @@ async fn execute_query( displayable(physical_plan.as_ref()).indent() ); } - let runtime = ctx.state.lock().runtime_env.clone(); - let 
result = collect(physical_plan.clone(), runtime).await?; + let task_ctx = ctx.task_ctx(); + let result = collect(physical_plan.clone(), task_ctx).await?; if debug { println!( "=== Physical plan with metrics ===\n{}\n", @@ -632,8 +630,8 @@ async fn convert_tbl(opt: ConvertOpt) -> Result<()> { .delimiter(b'|') .file_extension(".tbl"); - let config = ExecutionConfig::new().with_batch_size(opt.batch_size); - let mut ctx = ExecutionContext::with_config(config); + let config = SessionConfig::new().with_batch_size(opt.batch_size); + let mut ctx = SessionContext::with_config(config); // build plan to read the TBL file let mut csv = ctx.read_csv(&input_path, options).await?; @@ -1282,10 +1280,10 @@ mod tests { async fn run_query(n: usize) -> Result<()> { // Tests running query with empty tables, to see whether they run succesfully. - let config = ExecutionConfig::new() + let config = SessionConfig::new() .with_target_partitions(1) .with_batch_size(10); - let mut ctx = ExecutionContext::with_config(config); + let mut ctx = SessionContext::with_config(config); for &table in TABLES { let schema = get_schema(table); @@ -1297,7 +1295,7 @@ mod tests { } let plan = create_logical_plan(&mut ctx, n)?; - execute_query(&mut ctx, &plan, false).await?; + execute_query(&ctx, &plan, false).await?; Ok(()) } @@ -1307,7 +1305,7 @@ mod tests { // load expected answers from tpch-dbgen // read csv as all strings, trim and cast to expected type as the csv string // to value parser does not handle data with leading/trailing spaces - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let schema = string_schema(get_answer_schema(n)); let options = CsvReadOptions::new() .schema(&schema) @@ -1379,10 +1377,10 @@ mod tests { use datafusion::physical_plan::ExecutionPlan; async fn round_trip_query(n: usize) -> Result<()> { - let config = ExecutionConfig::new() + let config = SessionConfig::new() .with_target_partitions(1) .with_batch_size(10); - let mut ctx = ExecutionContext::with_config(config); + let mut ctx = SessionContext::with_config(config); let codec: BallistaCodec< protobuf::LogicalPlanNode, protobuf::PhysicalPlanNode, diff --git a/datafusion-cli/src/context.rs b/datafusion-cli/src/context.rs index 4f29af9354000..dc609bc68ea81 100644 --- a/datafusion-cli/src/context.rs +++ b/datafusion-cli/src/context.rs @@ -19,13 +19,13 @@ use datafusion::dataframe::DataFrame; use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::context::{ExecutionConfig, ExecutionContext}; +use datafusion::execution::context::{SessionConfig, SessionContext}; use std::sync::Arc; /// The CLI supports using a local DataFusion context or a distributed BallistaContext pub enum Context { /// In-process execution with DataFusion - Local(ExecutionContext), + Local(SessionContext), /// Distributed execution with Ballista (if available) Remote(BallistaContext), } @@ -37,8 +37,8 @@ impl Context { } /// create a local context using the given config - pub fn new_local(config: &ExecutionConfig) -> Context { - Context::Local(ExecutionContext::with_config(config.clone())) + pub fn new_local(config: &SessionConfig) -> Context { + Context::Local(SessionContext::with_config(config.clone())) } /// execute an SQL statement against the context diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs index 08878f9c70eb6..92b997d41e0d8 100644 --- a/datafusion-cli/src/main.rs +++ b/datafusion-cli/src/main.rs @@ -17,7 +17,7 @@ use clap::Parser; use datafusion::error::Result; -use 
datafusion::execution::context::ExecutionConfig; +use datafusion::execution::context::SessionConfig; use datafusion_cli::{ context::Context, exec, print_format::PrintFormat, print_options::PrintOptions, DATAFUSION_CLI_VERSION, @@ -98,15 +98,15 @@ pub async fn main() -> Result<()> { env::set_current_dir(&p).unwrap(); }; - let mut execution_config = ExecutionConfig::new().with_information_schema(true); + let mut session_config = SessionConfig::new().with_information_schema(true); if let Some(batch_size) = args.batch_size { - execution_config = execution_config.with_batch_size(batch_size); + session_config = session_config.with_batch_size(batch_size); }; let mut ctx: Context = match (args.host, args.port) { (Some(ref h), Some(p)) => Context::new_remote(h, p)?, - _ => Context::new_local(&execution_config), + _ => Context::new_local(&session_config), }; let mut print_options = PrintOptions { diff --git a/datafusion-examples/examples/avro_sql.rs b/datafusion-examples/examples/avro_sql.rs index f08c12bbb73a6..dd11fb4f9b89b 100644 --- a/datafusion-examples/examples/avro_sql.rs +++ b/datafusion-examples/examples/avro_sql.rs @@ -25,7 +25,7 @@ use datafusion::prelude::*; #[tokio::main] async fn main() -> Result<()> { // create local execution context - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let testdata = datafusion::arrow::util::test_util::arrow_test_data(); diff --git a/datafusion-examples/examples/csv_sql.rs b/datafusion-examples/examples/csv_sql.rs index 5ad9bd7d4385f..59c6fa072e4e0 100644 --- a/datafusion-examples/examples/csv_sql.rs +++ b/datafusion-examples/examples/csv_sql.rs @@ -23,7 +23,7 @@ use datafusion::prelude::*; #[tokio::main] async fn main() -> Result<()> { // create local execution context - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let testdata = datafusion::test_util::arrow_test_data(); diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs index b3ef04d721561..3d725c4c61afd 100644 --- a/datafusion-examples/examples/custom_datasource.rs +++ b/datafusion-examples/examples/custom_datasource.rs @@ -22,7 +22,7 @@ use datafusion::arrow::record_batch::RecordBatch; use datafusion::dataframe::DataFrame; use datafusion::datasource::TableProvider; use datafusion::error::{DataFusionError, Result}; -use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::context::TaskContext; use datafusion::logical_plan::{Expr, LogicalPlanBuilder}; use datafusion::physical_plan::expressions::PhysicalSortExpr; use datafusion::physical_plan::memory::MemoryStream; @@ -57,7 +57,7 @@ async fn search_accounts( expected_result_length: usize, ) -> Result<()> { // create local execution context - let ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // create logical plan composed of a single TableScan let logical_plan = @@ -235,7 +235,7 @@ impl ExecutionPlan for CustomExec { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { let users: Vec = { let db = self.db.inner.lock().unwrap(); diff --git a/datafusion-examples/examples/dataframe.rs b/datafusion-examples/examples/dataframe.rs index 6fd34610ba5c9..3df43c51bc915 100644 --- a/datafusion-examples/examples/dataframe.rs +++ b/datafusion-examples/examples/dataframe.rs @@ -23,7 +23,7 @@ use datafusion::prelude::*; #[tokio::main] async fn main() -> Result<()> { // create local execution context - let mut ctx = ExecutionContext::new(); + let mut ctx = 
SessionContext::new(); let testdata = datafusion::arrow::util::test_util::parquet_test_data(); diff --git a/datafusion-examples/examples/dataframe_in_memory.rs b/datafusion-examples/examples/dataframe_in_memory.rs index e17c69ed1ded4..67504e9b5baa6 100644 --- a/datafusion-examples/examples/dataframe_in_memory.rs +++ b/datafusion-examples/examples/dataframe_in_memory.rs @@ -44,7 +44,7 @@ async fn main() -> Result<()> { )?; // declare a new context. In spark API, this corresponds to a new spark SQLsession - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); // declare a table in memory. In spark API, this corresponds to createDataFrame(...). let provider = MemTable::try_new(schema, vec![vec![batch]])?; diff --git a/datafusion-examples/examples/flight_server.rs b/datafusion-examples/examples/flight_server.rs index c26dcce59f69d..04cb49fe7c22c 100644 --- a/datafusion-examples/examples/flight_server.rs +++ b/datafusion-examples/examples/flight_server.rs @@ -90,7 +90,7 @@ impl FlightService for FlightServiceImpl { println!("do_get: {}", sql); // create local execution context - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let testdata = datafusion::arrow::util::test_util::parquet_test_data(); diff --git a/datafusion-examples/examples/memtable.rs b/datafusion-examples/examples/memtable.rs index e113d98db6774..ccae8c5005427 100644 --- a/datafusion-examples/examples/memtable.rs +++ b/datafusion-examples/examples/memtable.rs @@ -20,7 +20,7 @@ use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use datafusion::arrow::record_batch::RecordBatch; use datafusion::datasource::MemTable; use datafusion::error::Result; -use datafusion::prelude::ExecutionContext; +use datafusion::prelude::SessionContext; use std::sync::Arc; use std::time::Duration; use tokio::time::timeout; @@ -31,7 +31,7 @@ async fn main() -> Result<()> { let mem_table = create_memtable()?; // create local execution context - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); // Register the in-memory table containing the data ctx.register_table("users", Arc::new(mem_table))?; diff --git a/datafusion-examples/examples/parquet_sql.rs b/datafusion-examples/examples/parquet_sql.rs index e74ed39c68ce2..39deb633b7a5d 100644 --- a/datafusion-examples/examples/parquet_sql.rs +++ b/datafusion-examples/examples/parquet_sql.rs @@ -22,8 +22,8 @@ use datafusion::prelude::*; /// fetching results #[tokio::main] async fn main() -> Result<()> { - // create local execution context - let mut ctx = ExecutionContext::new(); + // create local session context + let mut ctx = SessionContext::new(); let testdata = datafusion::arrow::util::test_util::parquet_test_data(); diff --git a/datafusion-examples/examples/parquet_sql_multiple_files.rs b/datafusion-examples/examples/parquet_sql_multiple_files.rs index 7485bc72f1931..4f001eeba8f99 100644 --- a/datafusion-examples/examples/parquet_sql_multiple_files.rs +++ b/datafusion-examples/examples/parquet_sql_multiple_files.rs @@ -28,7 +28,7 @@ use std::sync::Arc; #[tokio::main] async fn main() -> Result<()> { // create local execution context - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let testdata = datafusion::arrow::util::test_util::parquet_test_data(); diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index 3acace27e4de3..2b69e8d1e651a 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ 
b/datafusion-examples/examples/simple_udaf.rs @@ -28,8 +28,8 @@ use datafusion::{error::Result, logical_plan::create_udaf, physical_plan::Accumu use datafusion::{prelude::*, scalar::ScalarValue}; use std::sync::Arc; -// create local execution context with an in-memory table -fn create_context() -> Result { +// create local session context with an in-memory table +fn create_context() -> Result { use datafusion::arrow::datatypes::{Field, Schema}; use datafusion::datasource::MemTable; // define a schema. @@ -46,7 +46,7 @@ fn create_context() -> Result { )?; // declare a new context. In spark API, this corresponds to a new spark SQLsession - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); // declare a table in memory. In spark API, this corresponds to createDataFrame(...). let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?; diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index 33242c7b9870b..70ad04e7e6957 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -30,7 +30,7 @@ use datafusion::{error::Result, physical_plan::functions::make_scalar_function}; use std::sync::Arc; // create local execution context with an in-memory table -fn create_context() -> Result { +fn create_context() -> Result { use datafusion::arrow::datatypes::{Field, Schema}; use datafusion::datasource::MemTable; // define a schema. @@ -49,7 +49,7 @@ fn create_context() -> Result { )?; // declare a new context. In spark API, this corresponds to a new spark SQLsession - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); // declare a table in memory. In spark API, this corresponds to createDataFrame(...). 
let provider = MemTable::try_new(schema, vec![vec![batch]])?; diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml index 80842272d6135..e2523b2b93347 100644 --- a/datafusion/Cargo.toml +++ b/datafusion/Cargo.toml @@ -81,6 +81,7 @@ num-traits = { version = "0.2", optional = true } pyo3 = { version = "0.16", optional = true } tempfile = "3" parking_lot = "0.12" +uuid = { version = "0.8", features = ["v4"] } [dev-dependencies] criterion = "0.3" diff --git a/datafusion/benches/aggregate_query_sql.rs b/datafusion/benches/aggregate_query_sql.rs index e587fe58cd44d..78e914cb39849 100644 --- a/datafusion/benches/aggregate_query_sql.rs +++ b/datafusion/benches/aggregate_query_sql.rs @@ -24,12 +24,12 @@ mod data_utils; use crate::criterion::Criterion; use data_utils::create_table_provider; use datafusion::error::Result; -use datafusion::execution::context::ExecutionContext; +use datafusion::execution::context::SessionContext; use parking_lot::Mutex; use std::sync::Arc; use tokio::runtime::Runtime; -fn query(ctx: Arc>, sql: &str) { +fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); criterion::black_box(rt.block_on(df.collect()).unwrap()); @@ -39,8 +39,8 @@ fn create_context( partitions_len: usize, array_len: usize, batch_size: usize, -) -> Result>> { - let mut ctx = ExecutionContext::new(); +) -> Result>> { + let mut ctx = SessionContext::new(); let provider = create_table_provider(partitions_len, array_len, batch_size)?; ctx.register_table("t", provider)?; Ok(Arc::new(Mutex::new(ctx))) diff --git a/datafusion/benches/filter_query_sql.rs b/datafusion/benches/filter_query_sql.rs index 9885918de2296..e3401b11a9035 100644 --- a/datafusion/benches/filter_query_sql.rs +++ b/datafusion/benches/filter_query_sql.rs @@ -22,13 +22,13 @@ use arrow::{ }; use criterion::{criterion_group, criterion_main, Criterion}; use datafusion::from_slice::FromSlice; -use datafusion::prelude::ExecutionContext; +use datafusion::prelude::SessionContext; use datafusion::{datasource::MemTable, error::Result}; use futures::executor::block_on; use std::sync::Arc; use tokio::runtime::Runtime; -async fn query(ctx: &mut ExecutionContext, sql: &str) { +async fn query(ctx: &mut SessionContext, sql: &str) { let rt = Runtime::new().unwrap(); // execute the query @@ -36,7 +36,7 @@ async fn query(ctx: &mut ExecutionContext, sql: &str) { criterion::black_box(rt.block_on(df.collect()).unwrap()); } -fn create_context(array_len: usize, batch_size: usize) -> Result { +fn create_context(array_len: usize, batch_size: usize) -> Result { // define a schema. let schema = Arc::new(Schema::new(vec![ Field::new("f32", DataType::Float32, false), @@ -57,7 +57,7 @@ fn create_context(array_len: usize, batch_size: usize) -> Result>(); - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); // declare a table in memory. In spark API, this corresponds to createDataFrame(...). 
let provider = MemTable::try_new(schema, vec![batches])?; diff --git a/datafusion/benches/math_query_sql.rs b/datafusion/benches/math_query_sql.rs index 6195937dc4e51..88d486a5dcc59 100644 --- a/datafusion/benches/math_query_sql.rs +++ b/datafusion/benches/math_query_sql.rs @@ -34,10 +34,10 @@ use arrow::{ }; use datafusion::datasource::MemTable; use datafusion::error::Result; -use datafusion::execution::context::ExecutionContext; +use datafusion::execution::context::SessionContext; use datafusion::from_slice::FromSlice; -fn query(ctx: Arc>, sql: &str) { +fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); // execute the query @@ -48,7 +48,7 @@ fn query(ctx: Arc>, sql: &str) { fn create_context( array_len: usize, batch_size: usize, -) -> Result>> { +) -> Result>> { // define a schema. let schema = Arc::new(Schema::new(vec![ Field::new("f32", DataType::Float32, false), @@ -69,7 +69,7 @@ fn create_context( }) .collect::>(); - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); // declare a table in memory. In spark API, this corresponds to createDataFrame(...). let provider = MemTable::try_new(schema, vec![batches])?; diff --git a/datafusion/benches/parquet_query_sql.rs b/datafusion/benches/parquet_query_sql.rs index 17bc78bd038a7..183f647390698 100644 --- a/datafusion/benches/parquet_query_sql.rs +++ b/datafusion/benches/parquet_query_sql.rs @@ -24,7 +24,7 @@ use arrow::datatypes::{ }; use arrow::record_batch::RecordBatch; use criterion::{criterion_group, criterion_main, Criterion}; -use datafusion::prelude::ExecutionContext; +use datafusion::prelude::SessionContext; use parquet::arrow::ArrowWriter; use parquet::file::properties::{WriterProperties, WriterVersion}; use rand::distributions::uniform::SampleUniform; @@ -193,7 +193,7 @@ fn criterion_benchmark(c: &mut Criterion) { assert!(Path::new(&file_path).exists(), "path not found"); println!("Using parquet file {}", file_path); - let mut context = ExecutionContext::new(); + let mut context = SessionContext::new(); let rt = tokio::runtime::Builder::new_multi_thread().build().unwrap(); rt.block_on(context.register_parquet("t", file_path.as_str())) diff --git a/datafusion/benches/physical_plan.rs b/datafusion/benches/physical_plan.rs index 8dd1f49d183e3..ebf36ec56e5e6 100644 --- a/datafusion/benches/physical_plan.rs +++ b/datafusion/benches/physical_plan.rs @@ -29,17 +29,21 @@ use arrow::{ }; use tokio::runtime::Runtime; -use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::physical_plan::sorts::sort_preserving_merge::SortPreservingMergeExec; use datafusion::physical_plan::{ collect, expressions::{col, PhysicalSortExpr}, memory::MemoryExec, }; +use datafusion::prelude::SessionContext; // Initialise the operator using the provided record batches and the sort key // as inputs. All record batches must have the same schema. -fn sort_preserving_merge_operator(batches: Vec, sort: &[&str]) { +fn sort_preserving_merge_operator( + session_ctx: Arc, + batches: Vec, + sort: &[&str], +) { let schema = batches[0].schema(); let sort = sort @@ -57,10 +61,9 @@ fn sort_preserving_merge_operator(batches: Vec, sort: &[&str]) { ) .unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); - + let task_ctx = session_ctx.task_ctx(); let rt = Runtime::new().unwrap(); - let rt_env = Arc::new(RuntimeEnv::default()); - rt.block_on(collect(merge, rt_env)).unwrap(); + rt.block_on(collect(merge, task_ctx)).unwrap(); } // Produces `n` record batches of row size `m`. 
Each record batch will have @@ -161,12 +164,18 @@ fn criterion_benchmark(c: &mut Criterion) { ), ]; + let ctx = Arc::new(SessionContext::new()); for (name, input) in benches { + let ctx_clone = ctx.clone(); c.bench_function(name, move |b| { b.iter_batched( || input.clone(), |input| { - sort_preserving_merge_operator(input, &["a", "b", "c", "d"]); + sort_preserving_merge_operator( + ctx_clone.clone(), + input, + &["a", "b", "c", "d"], + ); }, BatchSize::LargeInput, ) diff --git a/datafusion/benches/sort_limit_query_sql.rs b/datafusion/benches/sort_limit_query_sql.rs index 2434341ae51c0..097191e0f02b1 100644 --- a/datafusion/benches/sort_limit_query_sql.rs +++ b/datafusion/benches/sort_limit_query_sql.rs @@ -31,11 +31,11 @@ extern crate datafusion; use arrow::datatypes::{DataType, Field, Schema}; use datafusion::datasource::MemTable; -use datafusion::execution::context::ExecutionContext; +use datafusion::execution::context::SessionContext; use tokio::runtime::Runtime; -fn query(ctx: Arc>, sql: &str) { +fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); // execute the query @@ -43,7 +43,7 @@ fn query(ctx: Arc>, sql: &str) { rt.block_on(df.collect()).unwrap(); } -fn create_context() -> Arc> { +fn create_context() -> Arc> { // define schema for data source (csv file) let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::Utf8, false), @@ -76,18 +76,18 @@ fn create_context() -> Arc> { let rt = Runtime::new().unwrap(); - let ctx_holder: Arc>>>> = + let ctx_holder: Arc>>>> = Arc::new(Mutex::new(vec![])); let partitions = 16; rt.block_on(async { - // create local execution context - let mut ctx = ExecutionContext::new(); + // create local session context + let mut ctx = SessionContext::new(); ctx.state.lock().config.target_partitions = 1; - let runtime = ctx.state.lock().runtime_env.clone(); - let mem_table = MemTable::load(Arc::new(csv.await), Some(partitions), runtime) + let task_ctx = ctx.task_ctx(); + let mem_table = MemTable::load(Arc::new(csv.await), Some(partitions), task_ctx) .await .unwrap(); ctx.register_table("aggregate_test_100", Arc::new(mem_table)) diff --git a/datafusion/benches/window_query_sql.rs b/datafusion/benches/window_query_sql.rs index dad838eb7f628..f76958bbbd32a 100644 --- a/datafusion/benches/window_query_sql.rs +++ b/datafusion/benches/window_query_sql.rs @@ -24,12 +24,12 @@ mod data_utils; use crate::criterion::Criterion; use data_utils::create_table_provider; use datafusion::error::Result; -use datafusion::execution::context::ExecutionContext; +use datafusion::execution::context::SessionContext; use parking_lot::Mutex; use std::sync::Arc; use tokio::runtime::Runtime; -fn query(ctx: Arc>, sql: &str) { +fn query(ctx: Arc>, sql: &str) { let rt = Runtime::new().unwrap(); let df = rt.block_on(ctx.lock().sql(sql)).unwrap(); criterion::black_box(rt.block_on(df.collect()).unwrap()); @@ -39,8 +39,8 @@ fn create_context( partitions_len: usize, array_len: usize, batch_size: usize, -) -> Result>> { - let mut ctx = ExecutionContext::new(); +) -> Result>> { + let mut ctx = SessionContext::new(); let provider = create_table_provider(partitions_len, array_len, batch_size)?; ctx.register_table("t", provider)?; Ok(Arc::new(Mutex::new(ctx))) diff --git a/datafusion/src/catalog/schema.rs b/datafusion/src/catalog/schema.rs index a97590af216e5..f5e8b9afdd182 100644 --- a/datafusion/src/catalog/schema.rs +++ b/datafusion/src/catalog/schema.rs @@ -251,7 +251,7 @@ mod tests { }; use crate::datasource::empty::EmptyTable; use 
crate::datasource::object_store::local::LocalFileSystem; - use crate::execution::context::ExecutionContext; + use crate::execution::context::SessionContext; use futures::StreamExt; @@ -290,7 +290,7 @@ mod tests { let catalog = MemoryCatalogProvider::new(); catalog.register_schema("active", Arc::new(schema)); - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_catalog("cat", Arc::new(catalog)); diff --git a/datafusion/src/dataframe.rs b/datafusion/src/dataframe.rs index 7ea4fb5b6211a..643a7cea929dc 100644 --- a/datafusion/src/dataframe.rs +++ b/datafusion/src/dataframe.rs @@ -34,7 +34,7 @@ use crate::arrow::datatypes::SchemaRef; use crate::arrow::util::pretty; use crate::datasource::TableProvider; use crate::datasource::TableType; -use crate::execution::context::{ExecutionContext, ExecutionContextState}; +use crate::execution::context::{SessionContext, SessionState, TaskContext}; use crate::physical_plan::file_format::{plan_to_csv, plan_to_parquet}; use crate::physical_plan::{collect, collect_partitioned}; use crate::physical_plan::{execute_stream, execute_stream_partitioned, ExecutionPlan}; @@ -48,7 +48,7 @@ use std::any::Any; /// [Spark DataFrame](https://spark.apache.org/docs/latest/sql-programming-guide.html) /// /// DataFrames are typically created by the `read_csv` and `read_parquet` methods on the -/// [ExecutionContext](../execution/context/struct.ExecutionContext.html) and can then be modified +/// [SessionContext](../execution/context/struct.SessionContext.html) and can then be modified /// by calling the transformation methods, such as `filter`, `select`, `aggregate`, and `limit` /// to build up a query definition. /// @@ -59,7 +59,7 @@ use std::any::Any; /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { -/// let mut ctx = ExecutionContext::new(); +/// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.filter(col("a").lt_eq(col("b")))? /// .aggregate(vec![col("a")], vec![min(col("b"))])? 
@@ -69,23 +69,23 @@ use std::any::Any; /// # } /// ``` pub struct DataFrame { - ctx_state: Arc>, + session_state: Arc>, plan: LogicalPlan, } impl DataFrame { /// Create a new Table based on an existing logical plan - pub fn new(ctx_state: Arc>, plan: &LogicalPlan) -> Self { + pub fn new(session_state: Arc>, plan: &LogicalPlan) -> Self { Self { - ctx_state, + session_state, plan: plan.clone(), } } /// Create a physical plan pub async fn create_physical_plan(&self) -> Result> { - let state = self.ctx_state.lock().clone(); - let ctx = ExecutionContext::from(Arc::new(Mutex::new(state))); + let state = self.session_state.lock().clone(); + let ctx = SessionContext::from(Arc::new(Mutex::new(state))); let plan = ctx.optimize(&self.plan)?; ctx.create_physical_plan(&plan).await } @@ -98,7 +98,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.select_columns(&["a", "b"])?; /// # Ok(()) @@ -120,7 +120,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.select(vec![col("a") * col("b"), col("c")])?; /// # Ok(()) @@ -136,7 +136,7 @@ impl DataFrame { let project_plan = LogicalPlanBuilder::from(plan).project(expr_list)?.build()?; Ok(Arc::new(DataFrame::new( - self.ctx_state.clone(), + self.session_state.clone(), &project_plan, ))) } @@ -148,7 +148,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.filter(col("a").lt_eq(col("b")))?; /// # Ok(()) @@ -158,7 +158,7 @@ impl DataFrame { let plan = LogicalPlanBuilder::from(self.to_logical_plan()) .filter(predicate)? .build()?; - Ok(Arc::new(DataFrame::new(self.ctx_state.clone(), &plan))) + Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan))) } /// Perform an aggregate query with optional grouping expressions. @@ -168,7 +168,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// /// // The following use is the equivalent of "SELECT MIN(b) GROUP BY a" @@ -187,7 +187,7 @@ impl DataFrame { let plan = LogicalPlanBuilder::from(self.to_logical_plan()) .aggregate(group_expr, aggr_expr)? .build()?; - Ok(Arc::new(DataFrame::new(self.ctx_state.clone(), &plan))) + Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan))) } /// Limit the number of rows returned from this DataFrame. 
@@ -197,7 +197,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.limit(100)?; /// # Ok(()) @@ -207,7 +207,7 @@ impl DataFrame { let plan = LogicalPlanBuilder::from(self.to_logical_plan()) .limit(n)? .build()?; - Ok(Arc::new(DataFrame::new(self.ctx_state.clone(), &plan))) + Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan))) } /// Calculate the union two [`DataFrame`]s. The two [`DataFrame`]s must have exactly the same schema @@ -217,7 +217,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.union(df.clone())?; /// # Ok(()) @@ -227,7 +227,7 @@ impl DataFrame { let plan = LogicalPlanBuilder::from(self.to_logical_plan()) .union(dataframe.to_logical_plan())? .build()?; - Ok(Arc::new(DataFrame::new(self.ctx_state.clone(), &plan))) + Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan))) } /// Calculate the union distinct two [`DataFrame`]s. The two [`DataFrame`]s must have exactly the same schema @@ -237,7 +237,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.union(df.clone())?; /// let df = df.distinct()?; @@ -246,7 +246,7 @@ impl DataFrame { /// ``` pub fn distinct(&self) -> Result> { Ok(Arc::new(DataFrame::new( - self.ctx_state.clone(), + self.session_state.clone(), &LogicalPlanBuilder::from(self.to_logical_plan()) .distinct()? .build()?, @@ -261,7 +261,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.sort(vec![col("a").sort(true, true), col("b").sort(false, false)])?; /// # Ok(()) @@ -271,7 +271,7 @@ impl DataFrame { let plan = LogicalPlanBuilder::from(self.to_logical_plan()) .sort(expr)? .build()?; - Ok(Arc::new(DataFrame::new(self.ctx_state.clone(), &plan))) + Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan))) } /// Join this DataFrame with another DataFrame using the specified columns as join keys @@ -281,7 +281,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let left = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let right = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await? /// .select(vec![ @@ -307,7 +307,7 @@ impl DataFrame { (left_cols.to_vec(), right_cols.to_vec()), )? 
.build()?; - Ok(Arc::new(DataFrame::new(self.ctx_state.clone(), &plan))) + Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan))) } // TODO: add join_using @@ -319,7 +319,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df1 = df.repartition(Partitioning::RoundRobinBatch(4))?; /// # Ok(()) @@ -332,7 +332,7 @@ impl DataFrame { let plan = LogicalPlanBuilder::from(self.to_logical_plan()) .repartition(partitioning_scheme)? .build()?; - Ok(Arc::new(DataFrame::new(self.ctx_state.clone(), &plan))) + Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan))) } /// Convert the logical plan represented by this DataFrame into a physical plan and @@ -343,7 +343,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let batches = df.collect().await?; /// # Ok(()) @@ -351,8 +351,8 @@ impl DataFrame { /// ``` pub async fn collect(&self) -> Result> { let plan = self.create_physical_plan().await?; - let runtime = self.ctx_state.lock().runtime_env.clone(); - Ok(collect(plan, runtime).await?) + let task_ctx = Arc::new(TaskContext::from(&self.session_state.lock().clone())); + Ok(collect(plan, task_ctx).await?) } /// Print results. @@ -362,7 +362,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// df.show().await?; /// # Ok(()) @@ -380,7 +380,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// df.show_limit(10).await?; /// # Ok(()) @@ -398,7 +398,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let stream = df.execute_stream().await?; /// # Ok(()) @@ -406,8 +406,8 @@ impl DataFrame { /// ``` pub async fn execute_stream(&self) -> Result { let plan = self.create_physical_plan().await?; - let runtime = self.ctx_state.lock().runtime_env.clone(); - execute_stream(plan, runtime).await + let task_ctx = Arc::new(TaskContext::from(&self.session_state.lock().clone())); + execute_stream(plan, task_ctx).await } /// Executes this DataFrame and collects all results into a vector of vector of RecordBatch @@ -418,7 +418,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let batches = df.collect_partitioned().await?; /// # Ok(()) @@ -426,8 +426,8 @@ impl DataFrame { /// ``` pub async fn collect_partitioned(&self) -> 
Result>> { let plan = self.create_physical_plan().await?; - let runtime = self.ctx_state.lock().runtime_env.clone(); - Ok(collect_partitioned(plan, runtime).await?) + let task_ctx = Arc::new(TaskContext::from(&self.session_state.lock().clone())); + Ok(collect_partitioned(plan, task_ctx).await?) } /// Executes this DataFrame and returns one stream per partition. @@ -437,7 +437,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let batches = df.execute_stream_partitioned().await?; /// # Ok(()) @@ -447,8 +447,8 @@ impl DataFrame { &self, ) -> Result> { let plan = self.create_physical_plan().await?; - let runtime = self.ctx_state.lock().runtime_env.clone(); - Ok(execute_stream_partitioned(plan, runtime).await?) + let task_ctx = Arc::new(TaskContext::from(&self.session_state.lock().clone())); + Ok(execute_stream_partitioned(plan, task_ctx).await?) } /// Returns the schema describing the output of this DataFrame in terms of columns returned, @@ -459,7 +459,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let schema = df.schema(); /// # Ok(()) @@ -483,7 +483,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let batches = df.limit(100)?.explain(false, false)?.collect().await?; /// # Ok(()) @@ -493,7 +493,7 @@ impl DataFrame { let plan = LogicalPlanBuilder::from(self.to_logical_plan()) .explain(verbose, analyze)? 
.build()?; - Ok(Arc::new(DataFrame::new(self.ctx_state.clone(), &plan))) + Ok(Arc::new(DataFrame::new(self.session_state.clone(), &plan))) } /// Return a `FunctionRegistry` used to plan udf's calls @@ -503,7 +503,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let f = df.registry(); /// // use f.udf("name", vec![...]) to use the udf @@ -511,7 +511,7 @@ impl DataFrame { /// # } /// ``` pub fn registry(&self) -> Arc { - let registry = self.ctx_state.lock().clone(); + let registry = self.session_state.lock().clone(); Arc::new(registry) } @@ -522,7 +522,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.intersect(df.clone())?; /// # Ok(()) @@ -532,7 +532,7 @@ impl DataFrame { let left_plan = self.to_logical_plan(); let right_plan = dataframe.to_logical_plan(); Ok(Arc::new(DataFrame::new( - self.ctx_state.clone(), + self.session_state.clone(), &LogicalPlanBuilder::intersect(left_plan, right_plan, true)?, ))) } @@ -544,7 +544,7 @@ impl DataFrame { /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { - /// let mut ctx = ExecutionContext::new(); + /// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.except(df.clone())?; /// # Ok(()) @@ -555,7 +555,7 @@ impl DataFrame { let right_plan = dataframe.to_logical_plan(); Ok(Arc::new(DataFrame::new( - self.ctx_state.clone(), + self.session_state.clone(), &LogicalPlanBuilder::except(left_plan, right_plan, true)?, ))) } @@ -563,8 +563,8 @@ impl DataFrame { /// Write a `DataFrame` to a CSV file. pub async fn write_csv(&self, path: &str) -> Result<()> { let plan = self.create_physical_plan().await?; - let state = self.ctx_state.lock().clone(); - let ctx = ExecutionContext::from(Arc::new(Mutex::new(state))); + let state = self.session_state.lock().clone(); + let ctx = SessionContext::from(Arc::new(Mutex::new(state))); plan_to_csv(&ctx, plan, path).await } @@ -575,8 +575,8 @@ impl DataFrame { writer_properties: Option, ) -> Result<()> { let plan = self.create_physical_plan().await?; - let state = self.ctx_state.lock().clone(); - let ctx = ExecutionContext::from(Arc::new(Mutex::new(state))); + let state = self.session_state.lock().clone(); + let ctx = SessionContext::from(Arc::new(Mutex::new(state))); plan_to_parquet(&ctx, plan, path, writer_properties).await } } @@ -606,7 +606,10 @@ impl TableProvider for DataFrame { .as_ref() // construct projections .map_or_else( - || Ok(Arc::new(Self::new(self.ctx_state.clone(), &self.plan)) as Arc<_>), + || { + Ok(Arc::new(Self::new(self.session_state.clone(), &self.plan)) + as Arc<_>) + }, |projection| { let schema = TableProvider::schema(self).project(projection)?; let names = schema @@ -624,7 +627,7 @@ impl TableProvider for DataFrame { ))?; // add a limit if given Self::new( - self.ctx_state.clone(), + self.session_state.clone(), &limit .map_or_else(|| Ok(expr.clone()), |n| expr.limit(n))? 
.to_logical_plan(), @@ -641,7 +644,7 @@ mod tests { use super::*; use crate::execution::options::CsvReadOptions; use crate::physical_plan::{window_functions, ColumnarValue}; - use crate::{assert_batches_sorted_eq, execution::context::ExecutionContext}; + use crate::{assert_batches_sorted_eq, execution::context::SessionContext}; use crate::{logical_plan::*, test_util}; use arrow::datatypes::DataType; use datafusion_expr::ScalarFunctionImplementation; @@ -795,7 +798,7 @@ mod tests { #[tokio::test] async fn registry() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx, "aggregate_test_100").await?; // declare the udf @@ -871,7 +874,7 @@ mod tests { #[tokio::test] async fn register_table() -> Result<()> { let df = test_table().await?.select_columns(&["c1", "c12"])?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let df_impl = Arc::new(DataFrame::new(ctx.state.clone(), &df.to_logical_plan())); // register a dataframe as a table @@ -929,13 +932,13 @@ mod tests { /// Create a logical plan from a SQL query async fn create_plan(sql: &str) -> Result { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx, "aggregate_test_100").await?; ctx.create_logical_plan(sql) } async fn test_table_with_name(name: &str) -> Result> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx, name).await?; ctx.table(name) } @@ -945,7 +948,7 @@ mod tests { } async fn register_aggregate_csv( - ctx: &mut ExecutionContext, + ctx: &mut SessionContext, table_name: &str, ) -> Result<()> { let schema = test_util::aggr_test_schema(); diff --git a/datafusion/src/datasource/file_format/avro.rs b/datafusion/src/datasource/file_format/avro.rs index fa02d1ae28336..ed77faaf564e2 100644 --- a/datafusion/src/datasource/file_format/avro.rs +++ b/datafusion/src/datasource/file_format/avro.rs @@ -84,6 +84,7 @@ mod tests { use super::*; use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use crate::prelude::{SessionConfig, SessionContext}; use arrow::array::{ BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, TimestampMicrosecondArray, @@ -92,10 +93,12 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::new().with_batch_size(2))?); + let config = SessionConfig::new().with_batch_size(2); + let ctx = SessionContext::with_config(config); + let task_ctx = session_ctx.task_ctx(); let projection = None; let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - let stream = exec.execute(0, runtime).await?; + let stream = exec.execute(0, task_ctx).await?; let tt_batches = stream .map(|batch| { @@ -113,10 +116,11 @@ mod tests { #[tokio::test] async fn read_limit() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let projection = None; let exec = get_exec("alltypes_plain.avro", &projection, Some(1)).await?; - let batches = collect(exec, runtime).await?; + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(11, batches[0].num_columns()); assert_eq!(1, batches[0].num_rows()); @@ -126,7 +130,8 @@ mod tests { #[tokio::test] async fn read_alltypes_plain_avro() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let 
task_ctx = session_ctx.task_ctx(); let projection = None; let exec = get_exec("alltypes_plain.avro", &projection, None).await?; @@ -153,7 +158,7 @@ mod tests { x ); - let batches = collect(exec, runtime).await?; + let batches = collect(exec, task_ctx).await?; assert_eq!(batches.len(), 1); let expected = vec![ @@ -177,11 +182,12 @@ mod tests { #[tokio::test] async fn read_bool_alltypes_plain_avro() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let projection = Some(vec![1]); let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - let batches = collect(exec, runtime).await?; + let batches = collect(exec, task_ctx).await?; assert_eq!(batches.len(), 1); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -206,11 +212,12 @@ mod tests { #[tokio::test] async fn read_i32_alltypes_plain_avro() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let projection = Some(vec![0]); let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - let batches = collect(exec, runtime).await?; + let batches = collect(exec, task_ctx).await?; assert_eq!(batches.len(), 1); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -232,11 +239,12 @@ mod tests { #[tokio::test] async fn read_i96_alltypes_plain_avro() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let projection = Some(vec![10]); let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - let batches = collect(exec, runtime).await?; + let batches = collect(exec, task_ctx).await?; assert_eq!(batches.len(), 1); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -258,11 +266,12 @@ mod tests { #[tokio::test] async fn read_f32_alltypes_plain_avro() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let projection = Some(vec![6]); let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - let batches = collect(exec, runtime).await?; + let batches = collect(exec, task_ctx).await?; assert_eq!(batches.len(), 1); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -287,11 +296,12 @@ mod tests { #[tokio::test] async fn read_f64_alltypes_plain_avro() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let projection = Some(vec![7]); let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - let batches = collect(exec, runtime).await?; + let batches = collect(exec, task_ctx).await?; assert_eq!(batches.len(), 1); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -316,11 +326,12 @@ mod tests { #[tokio::test] async fn read_binary_alltypes_plain_avro() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let projection = Some(vec![9]); let exec = get_exec("alltypes_plain.avro", &projection, None).await?; - let batches = collect(exec, runtime).await?; + let batches = collect(exec, task_ctx).await?; assert_eq!(batches.len(), 1); assert_eq!(1, batches[0].num_columns()); 
assert_eq!(8, batches[0].num_rows()); diff --git a/datafusion/src/datasource/file_format/csv.rs b/datafusion/src/datasource/file_format/csv.rs index 6aa0d21235a43..3abe9e6482cd0 100644 --- a/datafusion/src/datasource/file_format/csv.rs +++ b/datafusion/src/datasource/file_format/csv.rs @@ -138,7 +138,7 @@ mod tests { use arrow::array::StringArray; use super::*; - use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use crate::prelude::{SessionConfig, SessionContext}; use crate::{ datasource::{ file_format::FileScanConfig, @@ -152,11 +152,13 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::new().with_batch_size(2))?); + let config = SessionConfig::new().with_batch_size(2); + let ctx = SessionContext::with_config(config); // skip column 9 that overflows the automaticly discovered column type of i64 (u64 would work) let projection = Some(vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12]); let exec = get_exec("aggregate_test_100.csv", &projection, None).await?; - let stream = exec.execute(0, runtime).await?; + let task_ctx = ctx.task_ctx(); + let stream = exec.execute(0, task_ctx).await?; let tt_batches: i32 = stream .map(|batch| { @@ -178,10 +180,11 @@ mod tests { #[tokio::test] async fn read_limit() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let projection = Some(vec![0, 1, 2, 3]); let exec = get_exec("aggregate_test_100.csv", &projection, Some(1)).await?; - let batches = collect(exec, runtime).await?; + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(4, batches[0].num_columns()); assert_eq!(1, batches[0].num_rows()); @@ -224,11 +227,12 @@ mod tests { #[tokio::test] async fn read_char_column() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let projection = Some(vec![0]); let exec = get_exec("aggregate_test_100.csv", &projection, None).await?; - let batches = collect(exec, runtime).await.expect("Collect batches"); + let batches = collect(exec, task_ctx).await.expect("Collect batches"); assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); diff --git a/datafusion/src/datasource/file_format/json.rs b/datafusion/src/datasource/file_format/json.rs index bdd5ef81d5592..2f5f631a5a30e 100644 --- a/datafusion/src/datasource/file_format/json.rs +++ b/datafusion/src/datasource/file_format/json.rs @@ -100,7 +100,7 @@ mod tests { use arrow::array::Int64Array; use super::*; - use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use crate::prelude::{SessionConfig, SessionContext}; use crate::{ datasource::{ file_format::FileScanConfig, @@ -114,10 +114,12 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::new().with_batch_size(2))?); + let config = SessionConfig::new().with_batch_size(2); + let ctx = SessionContext::with_config(config); let projection = None; let exec = get_exec(&projection, None).await?; - let stream = exec.execute(0, runtime).await?; + let task_ctx = ctx.task_ctx(); + let stream = exec.execute(0, task_ctx).await?; let tt_batches: i32 = stream .map(|batch| { @@ -139,10 +141,11 @@ mod tests { #[tokio::test] async fn read_limit() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = 
session_ctx.task_ctx(); let projection = None; let exec = get_exec(&projection, Some(1)).await?; - let batches = collect(exec, runtime).await?; + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(4, batches[0].num_columns()); assert_eq!(1, batches[0].num_rows()); @@ -168,11 +171,12 @@ mod tests { #[tokio::test] async fn read_int_column() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let projection = Some(vec![0]); let exec = get_exec(&projection, None).await?; - let batches = collect(exec, runtime).await.expect("Collect batches"); + let batches = collect(exec, task_ctx).await.expect("Collect batches"); assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); diff --git a/datafusion/src/datasource/file_format/parquet.rs b/datafusion/src/datasource/file_format/parquet.rs index d1d26e2c6d423..1dd6b02b325ff 100644 --- a/datafusion/src/datasource/file_format/parquet.rs +++ b/datafusion/src/datasource/file_format/parquet.rs @@ -367,7 +367,7 @@ mod tests { use super::*; - use crate::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; + use crate::prelude::{SessionConfig, SessionContext}; use arrow::array::{ BinaryArray, BooleanArray, Float32Array, Float64Array, Int32Array, TimestampNanosecondArray, @@ -376,10 +376,12 @@ mod tests { #[tokio::test] async fn read_small_batches() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::new(RuntimeConfig::new().with_batch_size(2))?); + let config = SessionConfig::new().with_batch_size(2); + let ctx = SessionContext::with_config(config); let projection = None; let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - let stream = exec.execute(0, runtime).await?; + let task_ctx = ctx.task_ctx(); + let stream = exec.execute(0, task_ctx).await?; let tt_batches = stream .map(|batch| { @@ -401,7 +403,8 @@ mod tests { #[tokio::test] async fn read_limit() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let projection = None; let exec = get_exec("alltypes_plain.parquet", &projection, Some(1)).await?; @@ -409,7 +412,7 @@ mod tests { assert_eq!(exec.statistics().num_rows, Some(8)); assert_eq!(exec.statistics().total_byte_size, Some(671)); assert!(exec.statistics().is_exact); - let batches = collect(exec, runtime).await?; + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(11, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -419,7 +422,8 @@ mod tests { #[tokio::test] async fn read_alltypes_plain_parquet() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let projection = None; let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; @@ -445,7 +449,7 @@ mod tests { y ); - let batches = collect(exec, runtime).await?; + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(11, batches[0].num_columns()); @@ -456,11 +460,12 @@ mod tests { #[tokio::test] async fn read_bool_alltypes_plain_parquet() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let projection = Some(vec![1]); let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - let batches = collect(exec, runtime).await?; + let 
batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -485,11 +490,12 @@ mod tests { #[tokio::test] async fn read_i32_alltypes_plain_parquet() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let projection = Some(vec![0]); let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - let batches = collect(exec, runtime).await?; + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -511,11 +517,12 @@ mod tests { #[tokio::test] async fn read_i96_alltypes_plain_parquet() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let projection = Some(vec![10]); let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - let batches = collect(exec, runtime).await?; + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -537,11 +544,12 @@ mod tests { #[tokio::test] async fn read_f32_alltypes_plain_parquet() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let projection = Some(vec![6]); let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - let batches = collect(exec, runtime).await?; + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -566,11 +574,12 @@ mod tests { #[tokio::test] async fn read_f64_alltypes_plain_parquet() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let projection = Some(vec![7]); let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - let batches = collect(exec, runtime).await?; + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); @@ -595,11 +604,12 @@ mod tests { #[tokio::test] async fn read_binary_alltypes_plain_parquet() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let projection = Some(vec![9]); let exec = get_exec("alltypes_plain.parquet", &projection, None).await?; - let batches = collect(exec, runtime).await?; + let batches = collect(exec, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(8, batches[0].num_rows()); diff --git a/datafusion/src/datasource/listing/helpers.rs b/datafusion/src/datasource/listing/helpers.rs index 58609385dd65a..92f4511845ad3 100644 --- a/datafusion/src/datasource/listing/helpers.rs +++ b/datafusion/src/datasource/listing/helpers.rs @@ -37,7 +37,7 @@ use log::debug; use crate::{ error::Result, - execution::context::ExecutionContext, + execution::context::SessionContext, logical_plan::{self, Expr, ExprVisitable, ExpressionVisitor, Recursion}, physical_plan::functions::Volatility, scalar::ScalarValue, @@ -242,7 +242,7 @@ pub async fn pruned_partition_list( // Filter the partitions using a local datafusion context // TODO having 
the external context would allow us to resolve `Volatility::Stable` // scalar functions (`ScalarFunction` & `ScalarUDF`) and `ScalarVariable`s - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let mut df = ctx.read_table(Arc::new(mem_table))?; for filter in applicable_filters { df = df.filter(filter.clone())?; diff --git a/datafusion/src/datasource/memory.rs b/datafusion/src/datasource/memory.rs index 5fad702672efc..90e429b187ca8 100644 --- a/datafusion/src/datasource/memory.rs +++ b/datafusion/src/datasource/memory.rs @@ -29,7 +29,7 @@ use async_trait::async_trait; use crate::datasource::TableProvider; use crate::error::{DataFusionError, Result}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::logical_plan::Expr; use crate::physical_plan::common; use crate::physical_plan::memory::MemoryExec; @@ -65,7 +65,7 @@ impl MemTable { pub async fn load( t: Arc, output_partitions: Option, - runtime: Arc, + context: Arc, ) -> Result { let schema = t.schema(); let exec = t.scan(&None, &[], None).await?; @@ -73,10 +73,10 @@ impl MemTable { let tasks = (0..partition_count) .map(|part_i| { - let runtime1 = runtime.clone(); + let context1 = context.clone(); let exec = exec.clone(); tokio::spawn(async move { - let stream = exec.execute(part_i, runtime1.clone()).await?; + let stream = exec.execute(part_i, context1.clone()).await?; common::collect(stream).await }) }) @@ -103,7 +103,7 @@ impl MemTable { let mut output_partitions = vec![]; for i in 0..exec.output_partitioning().partition_count() { // execute this *output* partition and collect all batches - let mut stream = exec.execute(i, runtime.clone()).await?; + let mut stream = exec.execute(i, context.clone()).await?; let mut batches = vec![]; while let Some(result) = stream.next().await { batches.push(result?); @@ -145,6 +145,7 @@ impl TableProvider for MemTable { mod tests { use super::*; use crate::from_slice::FromSlice; + use crate::prelude::SessionContext; use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; use arrow::error::ArrowError; @@ -153,7 +154,8 @@ mod tests { #[tokio::test] async fn test_with_projection() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), @@ -175,7 +177,7 @@ mod tests { // scan with projection let exec = provider.scan(&Some(vec![2, 1]), &[], None).await?; - let mut it = exec.execute(0, runtime).await?; + let mut it = exec.execute(0, task_ctx).await?; let batch2 = it.next().await.unwrap()?; assert_eq!(2, batch2.schema().fields().len()); assert_eq!("c", batch2.schema().field(0).name()); @@ -187,7 +189,8 @@ mod tests { #[tokio::test] async fn test_without_projection() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), @@ -206,7 +209,7 @@ mod tests { let provider = MemTable::try_new(schema, vec![vec![batch]])?; let exec = provider.scan(&None, &[], None).await?; - let mut it = exec.execute(0, runtime).await?; + let mut it = exec.execute(0, task_ctx).await?; let batch1 = it.next().await.unwrap()?; assert_eq!(3, batch1.schema().fields().len()); assert_eq!(3, batch1.num_columns()); @@ 
-316,7 +319,8 @@ mod tests { #[tokio::test] async fn test_merged_schema() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let mut metadata = HashMap::new(); metadata.insert("foo".to_string(), "bar".to_string()); @@ -361,7 +365,7 @@ mod tests { MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?; let exec = provider.scan(&None, &[], None).await?; - let mut it = exec.execute(0, runtime).await?; + let mut it = exec.execute(0, task_ctx).await?; let batch1 = it.next().await.unwrap()?; assert_eq!(3, batch1.schema().fields().len()); assert_eq!(3, batch1.num_columns()); diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 09fc1aef961a6..1361637270124 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -//! ExecutionContext contains methods for registering data sources and executing queries +//! SessionContext contains methods for registering data sources and executing queries use crate::{ catalog::{ catalog::{CatalogList, MemoryCatalogList}, @@ -91,6 +91,7 @@ use crate::variable::{VarProvider, VarType}; use async_trait::async_trait; use chrono::{DateTime, Utc}; use parquet::file::properties::WriterProperties; +use uuid::Uuid; use super::{ disk_manager::DiskManagerConfig, @@ -99,8 +100,9 @@ use super::{ DiskManager, MemoryManager, }; -/// ExecutionContext is the main interface for executing queries with DataFusion. The context -/// provides the following functionality: +/// SessionContext is the main interface for executing queries with DataFusion. It stands for +/// the connection between user and DataFusion/Ballista cluster. +/// The context provides the following functionality /// /// * Create DataFrame from a CSV or Parquet data source. /// * Register a CSV or Parquet data source as a table that can be referenced from a SQL query. @@ -115,7 +117,7 @@ use super::{ /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { -/// let mut ctx = ExecutionContext::new(); +/// let mut ctx = SessionContext::new(); /// let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; /// let df = df.filter(col("a").lt_eq(col("b")))? /// .aggregate(vec![col("a")], vec![min(col("b"))])? @@ -133,32 +135,34 @@ use super::{ /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { -/// let mut ctx = ExecutionContext::new(); +/// let mut ctx = SessionContext::new(); /// ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new()).await?; /// let results = ctx.sql("SELECT a, MIN(b) FROM example GROUP BY a LIMIT 100").await?; /// # Ok(()) /// # } /// ``` #[derive(Clone)] -pub struct ExecutionContext { +pub struct SessionContext { + /// Uuid for the session + session_id: String, /// Internal state for the context - pub state: Arc>, + pub state: Arc>, } -impl Default for ExecutionContext { +impl Default for SessionContext { fn default() -> Self { Self::new() } } -impl ExecutionContext { +impl SessionContext { /// Creates a new execution context using a default configuration. pub fn new() -> Self { - Self::with_config(ExecutionConfig::new()) + Self::with_config(SessionConfig::new()) } - /// Creates a new execution context using the provided configuration. 
- pub fn with_config(config: ExecutionConfig) -> Self { + /// Creates a new session context using the provided configuration. + pub fn with_config(config: SessionConfig) -> Self { let catalog_list = Arc::new(MemoryCatalogList::new()) as Arc; if config.create_default_catalog_and_schema { @@ -183,21 +187,23 @@ impl ExecutionContext { } let runtime_env = Arc::new(RuntimeEnv::new(config.runtime.clone()).unwrap()); - + let state = SessionState { + session_id: Uuid::new_v4().to_string(), + catalog_list, + scalar_functions: HashMap::new(), + aggregate_functions: HashMap::new(), + config, + execution_props: ExecutionProps::new(), + object_store_registry: Arc::new(ObjectStoreRegistry::new()), + runtime_env, + }; Self { - state: Arc::new(Mutex::new(ExecutionContextState { - catalog_list, - scalar_functions: HashMap::new(), - aggregate_functions: HashMap::new(), - config, - execution_props: ExecutionProps::new(), - object_store_registry: Arc::new(ObjectStoreRegistry::new()), - runtime_env, - })), + session_id: state.session_id.clone(), + state: Arc::new(Mutex::new(state)), } } - /// Return the [RuntimeEnv] used to run queries with this [ExecutionContext] + /// Return the [RuntimeEnv] used to run queries with this [SessionContext] pub fn runtime_env(&self) -> Arc { self.state.lock().runtime_env.clone() } @@ -643,9 +649,9 @@ impl ExecutionContext { /// /// Use [`table`] to get a specific table. /// - /// [`table`]: ExecutionContext::table + /// [`table`]: SessionContext::table #[deprecated( - note = "Please use the catalog provider interface (`ExecutionContext::catalog`) to examine available catalogs, schemas, and tables" + note = "Please use the catalog provider interface (`SessionContext::catalog`) to examine available catalogs, schemas, and tables" )] pub fn tables(&self) -> Result> { Ok(self @@ -753,15 +759,21 @@ impl ExecutionContext { trace!("Full Optimized logical plan:\n {:?}", plan); Ok(new_plan) } + + /// Get a new TaskContext to run in this session + pub fn task_ctx(&self) -> Arc { + Arc::new(TaskContext::from(self)) + } } -impl From>> for ExecutionContext { - fn from(state: Arc>) -> Self { - ExecutionContext { state } +impl From>> for SessionContext { + fn from(state: Arc>) -> Self { + let session_id = state.lock().session_id.clone(); + SessionContext { session_id, state } } } -impl FunctionRegistry for ExecutionContext { +impl FunctionRegistry for SessionContext { fn udfs(&self) -> HashSet { self.state.lock().udfs() } @@ -782,7 +794,7 @@ pub trait QueryPlanner { async fn create_physical_plan( &self, logical_plan: &LogicalPlan, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result>; } @@ -795,16 +807,18 @@ impl QueryPlanner for DefaultQueryPlanner { async fn create_physical_plan( &self, logical_plan: &LogicalPlan, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result> { let planner = DefaultPhysicalPlanner::default(); - planner.create_physical_plan(logical_plan, ctx_state).await + planner + .create_physical_plan(logical_plan, session_state) + .await } } /// Configuration options for execution context #[derive(Clone)] -pub struct ExecutionConfig { +pub struct SessionConfig { /// Number of partitions for query execution. Increasing partitions can increase concurrency. 
pub target_partitions: usize, /// Responsible for optimizing a logical plan @@ -837,7 +851,7 @@ pub struct ExecutionConfig { pub runtime: RuntimeConfig, } -impl Default for ExecutionConfig { +impl Default for SessionConfig { fn default() -> Self { Self { target_partitions: num_cpus::get(), @@ -878,7 +892,7 @@ impl Default for ExecutionConfig { } } -impl ExecutionConfig { +impl SessionConfig { /// Create an execution config with default setting pub fn new() -> Self { Default::default() @@ -1105,7 +1119,9 @@ impl ExecutionProps { /// Execution context for registering data sources and executing queries #[derive(Clone)] -pub struct ExecutionContextState { +pub struct SessionState { + /// Uuid for the session + session_id: String, /// Collection of catalogs containing schemas and ultimately TableProviders pub catalog_list: Arc, /// Scalar functions that are registered with the context @@ -1113,7 +1129,7 @@ pub struct ExecutionContextState { /// Aggregate functions registered in the context pub aggregate_functions: HashMap>, /// Context configuration - pub config: ExecutionConfig, + pub config: SessionConfig, /// Execution properties pub execution_props: ExecutionProps, /// Object Store that are registered with the context @@ -1122,20 +1138,22 @@ pub struct ExecutionContextState { pub runtime_env: Arc, } -impl Default for ExecutionContextState { +impl Default for SessionState { fn default() -> Self { Self::new() } } -impl ExecutionContextState { - /// Returns new ExecutionContextState +impl SessionState { + /// Returns new SessionState pub fn new() -> Self { - ExecutionContextState { + let session_id = Uuid::new_v4().to_string(); + SessionState { + session_id, catalog_list: Arc::new(MemoryCatalogList::new()), scalar_functions: HashMap::new(), aggregate_functions: HashMap::new(), - config: ExecutionConfig::new(), + config: SessionConfig::new(), execution_props: ExecutionProps::new(), object_store_registry: Arc::new(ObjectStoreRegistry::new()), runtime_env: Arc::new(RuntimeEnv::default()), @@ -1175,7 +1193,7 @@ impl ExecutionContextState { } } -impl ContextProvider for ExecutionContextState { +impl ContextProvider for SessionState { fn get_table_provider(&self, name: TableReference) -> Option> { let resolved_ref = self.resolve_table_ref(name); let schema = self.schema_for_ref(resolved_ref).ok()?; @@ -1208,7 +1226,7 @@ impl ContextProvider for ExecutionContextState { } } -impl FunctionRegistry for ExecutionContextState { +impl FunctionRegistry for SessionState { fn udfs(&self) -> HashSet { self.scalar_functions.keys().cloned().collect() } @@ -1236,6 +1254,74 @@ impl FunctionRegistry for ExecutionContextState { } } +/// Task Context Properties +pub enum TaskProperties { + ///SessionConfig + SessionConfig(SessionConfig), + /// Name-value pairs of task properties + KVPairs(HashMap), +} + +/// Task Execution Context +pub struct TaskContext { + /// Session Id + pub session_id: String, + /// Optional Task Identify + pub task_id: Option, + /// Task settings + pub task_settings: TaskProperties, + /// Runtime environment associated with this task context + pub runtime: Arc, +} + +impl TaskContext { + /// Create a new task context instance + pub fn new( + task_id: String, + session_id: String, + task_settings: HashMap, + runtime: Arc, + ) -> Self { + Self { + task_id: Some(task_id), + session_id, + task_settings: TaskProperties::KVPairs(task_settings), + runtime, + } + } +} + +/// Create a new task context instance from SessionContext +impl From<&SessionContext> for TaskContext { + fn from(session: 
&SessionContext) -> Self { + let state_clone = session.state.lock().clone(); + let session_id = session.session_id.clone(); + let config = state_clone.config; + let runtime = session.runtime_env(); + Self { + task_id: None, + session_id, + task_settings: TaskProperties::SessionConfig(config), + runtime, + } + } +} + +/// Create a new task context instance from SessionState +impl From<&SessionState> for TaskContext { + fn from(state: &SessionState) -> Self { + let session_id = state.session_id.clone(); + let config = state.config.clone(); + let runtime = state.runtime_env.clone(); + Self { + task_id: None, + session_id, + task_settings: TaskProperties::SessionConfig(config), + runtime, + } + } +} + #[cfg(test)] mod tests { use super::*; @@ -1271,16 +1357,16 @@ mod tests { async fn shared_memory_and_disk_manager() { // Demonstrate the ability to share DiskManager and // MemoryManager between two different executions. - let ctx1 = ExecutionContext::new(); + let ctx1 = SessionContext::new(); // configure with same memory / disk manager let memory_manager = ctx1.runtime_env().memory_manager.clone(); let disk_manager = ctx1.runtime_env().disk_manager.clone(); - let config = ExecutionConfig::new() + let config = SessionConfig::new() .with_existing_memory_manager(memory_manager.clone()) .with_existing_disk_manager(disk_manager.clone()); - let ctx2 = ExecutionContext::with_config(config); + let ctx2 = SessionContext::with_config(config); assert!(std::ptr::eq( Arc::as_ptr(&memory_manager), @@ -1594,7 +1680,7 @@ mod tests { #[tokio::test] async fn aggregate_decimal_min() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); // the data type of c1 is decimal(10,3) ctx.register_table("d_table", test::table_with_decimal()) .unwrap(); @@ -1618,7 +1704,7 @@ mod tests { #[tokio::test] async fn aggregate_decimal_max() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); // the data type of c1 is decimal(10,3) ctx.register_table("d_table", test::table_with_decimal()) .unwrap(); @@ -1643,7 +1729,7 @@ mod tests { #[tokio::test] async fn aggregate_decimal_sum() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); // the data type of c1 is decimal(10,3) ctx.register_table("d_table", test::table_with_decimal()) .unwrap(); @@ -1667,7 +1753,7 @@ mod tests { #[tokio::test] async fn aggregate_decimal_avg() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); // the data type of c1 is decimal(10,3) ctx.register_table("d_table", test::table_with_decimal()) .unwrap(); @@ -1982,7 +2068,7 @@ mod tests { #[tokio::test] async fn group_by_date_trunc() -> Result<()> { let tmp_dir = TempDir::new()?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let schema = Arc::new(Schema::new(vec![ Field::new("c2", DataType::UInt64, false), Field::new( @@ -2033,7 +2119,7 @@ mod tests { #[tokio::test] async fn group_by_largeutf8() { { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); // input data looks like: // A, 1 @@ -2083,7 +2169,7 @@ mod tests { #[tokio::test] async fn unprojected_filter() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let df = ctx .read_table(test::table_with_sequence(1, 3).unwrap()) .unwrap(); @@ -2108,7 +2194,7 @@ mod tests { #[tokio::test] async fn group_by_dictionary() { async fn run_test_case() { - let mut ctx = ExecutionContext::new(); + let mut ctx 
= SessionContext::new(); // input data looks like: // A, 1 @@ -2205,7 +2291,7 @@ mod tests { partitions: Vec>, ) -> Result> { let tmp_dir = TempDir::new()?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let schema = Arc::new(Schema::new(vec![ Field::new("c_group", DataType::Utf8, false), Field::new("c_int8", DataType::Int8, false), @@ -2428,7 +2514,7 @@ mod tests { #[tokio::test] async fn case_sensitive_identifiers_functions() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -2468,7 +2554,7 @@ mod tests { #[tokio::test] async fn case_builtin_math_expression() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let type_values = vec![ ( @@ -2538,7 +2624,7 @@ mod tests { #[tokio::test] async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -2579,7 +2665,7 @@ mod tests { #[tokio::test] async fn case_sensitive_identifiers_aggregates() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -2619,7 +2705,7 @@ mod tests { #[tokio::test] async fn case_sensitive_identifiers_user_defined_aggregates() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("t", test::table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -2666,7 +2752,7 @@ mod tests { // The main stipulation of this test: use a file extension that isn't .csv. let file_extension = ".tst"; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let schema = populate_csv_partitions(&tmp_dir, 2, file_extension)?; ctx.register_csv( "test", @@ -2695,7 +2781,7 @@ mod tests { #[tokio::test] async fn send_context_to_threads() -> Result<()> { - // ensure ExecutionContexts can be used in a multi-threaded + // ensure SessionContexts can be used in a multi-threaded // environment. Usecase is for concurrent planing. 
let tmp_dir = TempDir::new()?; let partition_count = 4; @@ -2722,7 +2808,7 @@ mod tests { #[tokio::test] async fn ctx_sql_should_optimize_plan() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let plan1 = ctx .create_logical_plan("SELECT * FROM (SELECT 1) AS one WHERE TRUE AND TRUE")?; @@ -2753,7 +2839,7 @@ mod tests { vec![Arc::new(Int32Array::from_slice(&[4, 5]))], )?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; @@ -2778,8 +2864,8 @@ mod tests { #[tokio::test] async fn custom_query_planner() -> Result<()> { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_query_planner(Arc::new(MyQueryPlanner {})), + let mut ctx = SessionContext::with_config( + SessionConfig::new().with_query_planner(Arc::new(MyQueryPlanner {})), ); let df = ctx.sql("SELECT 1").await?; @@ -2789,8 +2875,8 @@ mod tests { #[tokio::test] async fn disabled_default_catalog_and_schema() -> Result<()> { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().create_default_catalog_and_schema(false), + let mut ctx = SessionContext::with_config( + SessionConfig::new().create_default_catalog_and_schema(false), ); assert!(matches!( @@ -2808,8 +2894,8 @@ mod tests { #[tokio::test] async fn custom_catalog_and_schema() -> Result<()> { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new() + let mut ctx = SessionContext::with_config( + SessionConfig::new() .create_default_catalog_and_schema(false) .with_default_catalog_and_schema("my_catalog", "my_schema"), ); @@ -2842,7 +2928,7 @@ mod tests { #[tokio::test] async fn cross_catalog_access() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let catalog_a = MemoryCatalogProvider::new(); let schema_a = MemorySchemaProvider::new(); @@ -2887,8 +2973,8 @@ mod tests { #[tokio::test] async fn catalogs_not_leaked() { // the information schema used to introduce cyclic Arcs - let ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), + let ctx = SessionContext::with_config( + SessionConfig::new().with_information_schema(true), ); // register a single catalog @@ -2910,7 +2996,7 @@ mod tests { #[tokio::test] async fn normalized_column_identifiers() { // create local execution context - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); // register csv file with the execution context ctx.register_csv( @@ -3081,7 +3167,7 @@ mod tests { async fn create_physical_plan( &self, _logical_plan: &LogicalPlan, - _ctx_state: &ExecutionContextState, + _session_state: &SessionState, ) -> Result> { Err(DataFusionError::NotImplemented( "query not supported".to_string(), @@ -3093,7 +3179,7 @@ mod tests { _expr: &Expr, _input_dfschema: &crate::logical_plan::DFSchema, _input_schema: &Schema, - _ctx_state: &ExecutionContextState, + _session_state: &SessionState, ) -> Result> { unimplemented!() } @@ -3106,18 +3192,18 @@ mod tests { async fn create_physical_plan( &self, logical_plan: &LogicalPlan, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result> { let physical_planner = MyPhysicalPlanner {}; physical_planner - .create_physical_plan(logical_plan, ctx_state) + .create_physical_plan(logical_plan, session_state) .await } } /// Execute SQL and return results async fn plan_and_collect( - ctx: &mut ExecutionContext, + ctx: &mut SessionContext, sql: &str, ) -> 
Result> { ctx.sql(sql).await?.collect().await @@ -3163,10 +3249,9 @@ mod tests { async fn create_ctx( tmp_dir: &TempDir, partition_count: usize, - ) -> Result { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_target_partitions(8), - ); + ) -> Result { + let mut ctx = + SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); let schema = populate_csv_partitions(tmp_dir, partition_count, ".csv")?; @@ -3195,19 +3280,19 @@ mod tests { #[async_trait] impl CallReadTrait for CallRead { async fn call_read_csv(&self) -> Arc { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.read_csv("dummy", CsvReadOptions::new()).await.unwrap() } async fn call_read_avro(&self) -> Arc { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.read_avro("dummy", AvroReadOptions::default()) .await .unwrap() } async fn call_read_parquet(&self) -> Arc { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.read_parquet("dummy").await.unwrap() } } diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs index e514c0ad0cd41..2d396a61b325e 100644 --- a/datafusion/src/lib.rs +++ b/datafusion/src/lib.rs @@ -34,7 +34,7 @@ //! //! # #[tokio::main] //! # async fn main() -> Result<()> { -//! let mut ctx = ExecutionContext::new(); +//! let mut ctx = SessionContext::new(); //! //! // create the dataframe //! let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; @@ -73,7 +73,7 @@ //! //! # #[tokio::main] //! # async fn main() -> Result<()> { -//! let mut ctx = ExecutionContext::new(); +//! let mut ctx = SessionContext::new(); //! //! ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new()).await?; //! @@ -111,7 +111,7 @@ //! 3. The planner [`SqlToRel`](sql::planner::SqlToRel) converts logical nodes on the AST to a [`LogicalPlan`](logical_plan::LogicalPlan). //! 4. [`OptimizerRules`](optimizer::optimizer::OptimizerRule) are applied to the [`LogicalPlan`](logical_plan::LogicalPlan) to optimize it. //! 5. The [`LogicalPlan`](logical_plan::LogicalPlan) is converted to an [`ExecutionPlan`](physical_plan::ExecutionPlan) by a [`PhysicalPlanner`](physical_plan::PhysicalPlanner) -//! 6. The [`ExecutionPlan`](physical_plan::ExecutionPlan) is executed against data through the [`ExecutionContext`](execution::context::ExecutionContext) +//! 6. The [`ExecutionPlan`](physical_plan::ExecutionPlan) is executed against data through the [`SessionContext`](execution::context::SessionContext) //! //! With a [`DataFrame`](dataframe::DataFrame) API, steps 1-3 are not used as the DataFrame builds the [`LogicalPlan`](logical_plan::LogicalPlan) directly. //! 
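// A short sketch of the renamed configuration path used by the tests above,
// assuming the builder methods keep their pre-rename behaviour
// (with_target_partitions, with_information_schema, with_default_catalog_and_schema).
use datafusion::prelude::{SessionConfig, SessionContext};

fn configured_session_sketch() -> SessionContext {
    let config = SessionConfig::new()
        .with_target_partitions(8)
        .with_information_schema(true)
        .with_default_catalog_and_schema("my_catalog", "my_schema");

    // SessionContext::with_config is the drop-in replacement for
    // ExecutionContext::with_config.
    SessionContext::with_config(config)
}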
diff --git a/datafusion/src/physical_optimizer/aggregate_statistics.rs b/datafusion/src/physical_optimizer/aggregate_statistics.rs index 4ae6ce3638cc9..8a7d790b4e1bd 100644 --- a/datafusion/src/physical_optimizer/aggregate_statistics.rs +++ b/datafusion/src/physical_optimizer/aggregate_statistics.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow::datatypes::Schema; -use crate::execution::context::ExecutionConfig; +use crate::execution::context::SessionConfig; use crate::physical_plan::empty::EmptyExec; use crate::physical_plan::hash_aggregate::{AggregateMode, HashAggregateExec}; use crate::physical_plan::projection::ProjectionExec; @@ -48,7 +48,7 @@ impl PhysicalOptimizerRule for AggregateStatistics { fn optimize( &self, plan: Arc, - execution_config: &ExecutionConfig, + config: &SessionConfig, ) -> Result> { if let Some(partial_agg_exec) = take_optimizable(&*plan) { let partial_agg_exec = partial_agg_exec @@ -84,10 +84,10 @@ impl PhysicalOptimizerRule for AggregateStatistics { Arc::new(EmptyExec::new(true, Arc::new(Schema::empty()))), )?)) } else { - optimize_children(self, plan, execution_config) + optimize_children(self, plan, config) } } else { - optimize_children(self, plan, execution_config) + optimize_children(self, plan, config) } } @@ -259,7 +259,6 @@ mod tests { use arrow::record_batch::RecordBatch; use crate::error::Result; - use crate::execution::runtime_env::RuntimeEnv; use crate::logical_plan::Operator; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::common; @@ -267,6 +266,7 @@ mod tests { use crate::physical_plan::filter::FilterExec; use crate::physical_plan::hash_aggregate::HashAggregateExec; use crate::physical_plan::memory::MemoryExec; + use crate::prelude::SessionContext; /// Mock data using a MemoryExec which has an exact count statistic fn mock_data() -> Result> { @@ -295,8 +295,9 @@ mod tests { plan: HashAggregateExec, nulls: bool, ) -> Result<()> { - let conf = ExecutionConfig::new(); - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + let conf = session_ctx.state.lock().clone().config; let optimized = AggregateStatistics::new().optimize(Arc::new(plan), &conf)?; let (col, count) = match nulls { @@ -306,7 +307,7 @@ mod tests { // A ProjectionExec is a sign that the count optimization was applied assert!(optimized.as_any().is::()); - let result = common::collect(optimized.execute(0, runtime).await?).await?; + let result = common::collect(optimized.execute(0, task_ctx).await?).await?; assert_eq!(result[0].schema(), Arc::new(Schema::new(vec![col]))); assert_eq!( result[0] @@ -473,7 +474,7 @@ mod tests { Arc::clone(&schema), )?; - let conf = ExecutionConfig::new(); + let conf = SessionConfig::new(); let optimized = AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; @@ -515,7 +516,7 @@ mod tests { Arc::clone(&schema), )?; - let conf = ExecutionConfig::new(); + let conf = SessionConfig::new(); let optimized = AggregateStatistics::new().optimize(Arc::new(final_agg), &conf)?; diff --git a/datafusion/src/physical_optimizer/coalesce_batches.rs b/datafusion/src/physical_optimizer/coalesce_batches.rs index 98e65a2b12816..47d87d35f0e65 100644 --- a/datafusion/src/physical_optimizer/coalesce_batches.rs +++ b/datafusion/src/physical_optimizer/coalesce_batches.rs @@ -42,7 +42,7 @@ impl PhysicalOptimizerRule for CoalesceBatches { fn optimize( &self, plan: Arc, - config: &crate::execution::context::ExecutionConfig, + config: 
&crate::execution::context::SessionConfig, ) -> Result> { // wrap operators in CoalesceBatches to avoid lots of tiny batches when we have // highly selective filters diff --git a/datafusion/src/physical_optimizer/hash_build_probe_order.rs b/datafusion/src/physical_optimizer/hash_build_probe_order.rs index 244eb6a560b6d..565a955235315 100644 --- a/datafusion/src/physical_optimizer/hash_build_probe_order.rs +++ b/datafusion/src/physical_optimizer/hash_build_probe_order.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use arrow::datatypes::Schema; -use crate::execution::context::ExecutionConfig; +use crate::execution::context::SessionConfig; use crate::logical_plan::JoinType; use crate::physical_plan::cross_join::CrossJoinExec; use crate::physical_plan::expressions::Column; @@ -113,9 +113,9 @@ impl PhysicalOptimizerRule for HashBuildProbeOrder { fn optimize( &self, plan: Arc, - execution_config: &ExecutionConfig, + session_config: &SessionConfig, ) -> Result> { - let plan = optimize_children(self, plan, execution_config)?; + let plan = optimize_children(self, plan, session_config)?; if let Some(hash_join) = plan.as_any().downcast_ref::() { let left = hash_join.left(); let right = hash_join.right(); @@ -212,7 +212,7 @@ mod tests { .unwrap(); let optimized_join = HashBuildProbeOrder::new() - .optimize(Arc::new(join), &ExecutionConfig::new()) + .optimize(Arc::new(join), &SessionConfig::new()) .unwrap(); let swapping_projection = optimized_join @@ -259,7 +259,7 @@ mod tests { .unwrap(); let optimized_join = HashBuildProbeOrder::new() - .optimize(Arc::new(join), &ExecutionConfig::new()) + .optimize(Arc::new(join), &SessionConfig::new()) .unwrap(); let swapped_join = optimized_join diff --git a/datafusion/src/physical_optimizer/merge_exec.rs b/datafusion/src/physical_optimizer/merge_exec.rs index 58823a665b16e..f23da15e51529 100644 --- a/datafusion/src/physical_optimizer/merge_exec.rs +++ b/datafusion/src/physical_optimizer/merge_exec.rs @@ -40,7 +40,7 @@ impl PhysicalOptimizerRule for AddCoalescePartitionsExec { fn optimize( &self, plan: Arc, - config: &crate::execution::context::ExecutionConfig, + config: &crate::execution::context::SessionConfig, ) -> Result> { if plan.children().is_empty() { // leaf node, children cannot be replaced diff --git a/datafusion/src/physical_optimizer/optimizer.rs b/datafusion/src/physical_optimizer/optimizer.rs index e2f40ae954024..741fb48763e86 100644 --- a/datafusion/src/physical_optimizer/optimizer.rs +++ b/datafusion/src/physical_optimizer/optimizer.rs @@ -20,7 +20,7 @@ use std::sync::Arc; use crate::{ - error::Result, execution::context::ExecutionConfig, physical_plan::ExecutionPlan, + error::Result, execution::context::SessionConfig, physical_plan::ExecutionPlan, }; /// `PhysicalOptimizerRule` transforms one ['ExecutionPlan'] into another which @@ -31,7 +31,7 @@ pub trait PhysicalOptimizerRule { fn optimize( &self, plan: Arc, - config: &ExecutionConfig, + config: &SessionConfig, ) -> Result>; /// A human readable name for this optimizer rule diff --git a/datafusion/src/physical_optimizer/repartition.rs b/datafusion/src/physical_optimizer/repartition.rs index ae074d2893daf..00a9b9e610797 100644 --- a/datafusion/src/physical_optimizer/repartition.rs +++ b/datafusion/src/physical_optimizer/repartition.rs @@ -21,7 +21,7 @@ use std::sync::Arc; use super::optimizer::PhysicalOptimizerRule; use crate::physical_plan::Partitioning::*; use crate::physical_plan::{repartition::RepartitionExec, ExecutionPlan}; -use crate::{error::Result, execution::context::ExecutionConfig}; +use 
crate::{error::Result, execution::context::SessionConfig}; /// Optimizer that introduces repartition to introduce more /// parallelism in the plan @@ -218,7 +218,7 @@ impl PhysicalOptimizerRule for Repartition { fn optimize( &self, plan: Arc, - config: &ExecutionConfig, + config: &SessionConfig, ) -> Result> { // Don't run optimizer if target_partitions == 1 if config.target_partitions == 1 { @@ -343,7 +343,7 @@ mod tests { // run optimizer let optimizer = Repartition {}; let optimized = optimizer - .optimize($PLAN, &ExecutionConfig::new().with_target_partitions(10))?; + .optimize($PLAN, &SessionConfig::new().with_target_partitions(10))?; // Now format correctly let plan = displayable(optimized.as_ref()).indent().to_string(); diff --git a/datafusion/src/physical_optimizer/utils.rs b/datafusion/src/physical_optimizer/utils.rs index 962b8ce14557b..bb1415e18beeb 100644 --- a/datafusion/src/physical_optimizer/utils.rs +++ b/datafusion/src/physical_optimizer/utils.rs @@ -18,7 +18,7 @@ //! Collection of utility functions that are leveraged by the query optimizer rules use super::optimizer::PhysicalOptimizerRule; -use crate::execution::context::ExecutionConfig; +use crate::execution::context::SessionConfig; use crate::error::Result; use crate::physical_plan::ExecutionPlan; @@ -31,12 +31,12 @@ use std::sync::Arc; pub fn optimize_children( optimizer: &impl PhysicalOptimizerRule, plan: Arc, - execution_config: &ExecutionConfig, + session_config: &SessionConfig, ) -> Result> { let children = plan .children() .iter() - .map(|child| optimizer.optimize(Arc::clone(child), execution_config)) + .map(|child| optimizer.optimize(Arc::clone(child), session_config)) .collect::>>()?; if children.is_empty() { diff --git a/datafusion/src/physical_plan/analyze.rs b/datafusion/src/physical_plan/analyze.rs index 6857ad532273b..31f47a673359e 100644 --- a/datafusion/src/physical_plan/analyze.rs +++ b/datafusion/src/physical_plan/analyze.rs @@ -32,7 +32,7 @@ use futures::StreamExt; use super::expressions::PhysicalSortExpr; use super::{stream::RecordBatchReceiverStream, Distribution, SendableRecordBatchStream}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use async_trait::async_trait; /// `EXPLAIN ANALYZE` execution plan operator. 
This operator runs its input, @@ -112,7 +112,7 @@ impl ExecutionPlan for AnalyzeExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { if 0 != partition { return Err(DataFusionError::Internal(format!( @@ -133,7 +133,7 @@ impl ExecutionPlan for AnalyzeExec { let (tx, rx) = tokio::sync::mpsc::channel(input_partitions); let captured_input = self.input.clone(); - let mut input_stream = captured_input.execute(0, runtime).await?; + let mut input_stream = captured_input.execute(0, context).await?; let captured_schema = self.schema.clone(); let verbose = self.verbose; @@ -238,6 +238,7 @@ mod tests { use arrow::datatypes::{DataType, Field, Schema}; use futures::FutureExt; + use crate::prelude::SessionContext; use crate::{ physical_plan::collect, test::{ @@ -250,7 +251,8 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -258,7 +260,7 @@ mod tests { let refs = blocking_exec.refs(); let analyze_exec = Arc::new(AnalyzeExec::new(true, blocking_exec, schema)); - let fut = collect(analyze_exec, runtime); + let fut = collect(analyze_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/coalesce_batches.rs b/datafusion/src/physical_plan/coalesce_batches.rs index 0d6fe38636f66..18785354c6ec3 100644 --- a/datafusion/src/physical_plan/coalesce_batches.rs +++ b/datafusion/src/physical_plan/coalesce_batches.rs @@ -29,7 +29,7 @@ use crate::physical_plan::{ SendableRecordBatchStream, }; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use arrow::compute::kernels::concat::concat; use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; @@ -124,10 +124,10 @@ impl ExecutionPlan for CoalesceBatchesExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { Ok(Box::pin(CoalesceBatchesStream { - input: self.input.execute(partition, runtime).await?, + input: self.input.execute(partition, context).await?, schema: self.input.schema(), target_batch_size: self.target_batch_size, buffer: Vec::new(), @@ -305,6 +305,7 @@ pub fn concat_batches( mod tests { use super::*; use crate::physical_plan::{memory::MemoryExec, repartition::RepartitionExec}; + use crate::prelude::SessionContext; use crate::test::create_vec_batches; use arrow::datatypes::{DataType, Field, Schema}; @@ -348,10 +349,11 @@ mod tests { // execute and collect results let output_partition_count = exec.output_partitioning().partition_count(); let mut output_partitions = Vec::with_capacity(output_partition_count); - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); for i in 0..output_partition_count { // execute this *output* partition and collect all batches - let mut stream = exec.execute(i, runtime.clone()).await?; + let task_ctx = session_ctx.task_ctx(); + let mut stream = exec.execute(i, task_ctx.clone()).await?; let mut batches = vec![]; while let Some(result) = stream.next().await { batches.push(result?); diff --git a/datafusion/src/physical_plan/coalesce_partitions.rs b/datafusion/src/physical_plan/coalesce_partitions.rs index 20b5487337159..0e550f5ef7d68 100644 --- a/datafusion/src/physical_plan/coalesce_partitions.rs +++ b/datafusion/src/physical_plan/coalesce_partitions.rs @@ -38,7 
+38,7 @@ use crate::error::{DataFusionError, Result}; use crate::physical_plan::{DisplayFormatType, ExecutionPlan, Partitioning}; use super::SendableRecordBatchStream; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::physical_plan::common::spawn_execution; use pin_project_lite::pin_project; @@ -110,7 +110,7 @@ impl ExecutionPlan for CoalescePartitionsExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { // CoalescePartitionsExec produces a single partition if 0 != partition { @@ -127,7 +127,7 @@ impl ExecutionPlan for CoalescePartitionsExec { )), 1 => { // bypass any threading / metrics if there is a single partition - self.input.execute(0, runtime).await + self.input.execute(0, context).await } _ => { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); @@ -150,7 +150,7 @@ impl ExecutionPlan for CoalescePartitionsExec { self.input.clone(), sender.clone(), part_i, - runtime.clone(), + context.clone(), )); } @@ -224,13 +224,15 @@ mod tests { use crate::datasource::object_store::local::LocalFileSystem; use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; use crate::physical_plan::{collect, common}; + use crate::prelude::SessionContext; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use crate::test::{self, assert_is_pending}; use crate::test_util; #[tokio::test] async fn merge() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = test_util::aggr_test_schema(); let num_partitions = 4; @@ -259,7 +261,7 @@ mod tests { assert_eq!(merge.output_partitioning().partition_count(), 1); // the result should contain 4 batches (one per input partition) - let iter = merge.execute(0, runtime).await?; + let iter = merge.execute(0, task_ctx).await?; let batches = common::collect(iter).await?; assert_eq!(batches.len(), num_partitions); @@ -272,7 +274,8 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -281,7 +284,7 @@ mod tests { let coaelesce_partitions_exec = Arc::new(CoalescePartitionsExec::new(blocking_exec)); - let fut = collect(coaelesce_partitions_exec, runtime); + let fut = collect(coaelesce_partitions_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/common.rs b/datafusion/src/physical_plan/common.rs index bc4400d981862..b7313b9f25fd6 100644 --- a/datafusion/src/physical_plan/common.rs +++ b/datafusion/src/physical_plan/common.rs @@ -19,7 +19,7 @@ use super::{RecordBatchStream, SendableRecordBatchStream}; use crate::error::{DataFusionError, Result}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::physical_plan::metrics::MemTrackingMetrics; use crate::physical_plan::{ColumnStatistics, ExecutionPlan, Statistics}; use arrow::compute::concat; @@ -176,10 +176,10 @@ pub(crate) fn spawn_execution( input: Arc, mut output: mpsc::Sender>, partition: usize, - runtime: Arc, + context: Arc, ) -> JoinHandle<()> { tokio::spawn(async move { - let mut stream = match input.execute(partition, runtime).await { + let mut stream = match input.execute(partition, context).await { Err(e) => { // If send 
fails, plan being torn // down, no place to send the error diff --git a/datafusion/src/physical_plan/cross_join.rs b/datafusion/src/physical_plan/cross_join.rs index 82ee5618f5f06..efe8224211785 100644 --- a/datafusion/src/physical_plan/cross_join.rs +++ b/datafusion/src/physical_plan/cross_join.rs @@ -43,7 +43,7 @@ use super::{ coalesce_batches::concat_batches, memory::MemoryStream, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, }; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use log::debug; /// Data of the left side @@ -149,7 +149,7 @@ impl ExecutionPlan for CrossJoinExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { // we only want to compute the build side once let left_data = { @@ -162,7 +162,7 @@ impl ExecutionPlan for CrossJoinExec { // merge all left parts into a single stream let merge = CoalescePartitionsExec::new(self.left.clone()); - let stream = merge.execute(0, runtime.clone()).await?; + let stream = merge.execute(0, context.clone()).await?; // Load all batches and count the rows let (batches, num_rows) = stream @@ -187,7 +187,7 @@ impl ExecutionPlan for CrossJoinExec { } }; - let stream = self.right.execute(partition, runtime.clone()).await?; + let stream = self.right.execute(partition, context.clone()).await?; if left_data.num_rows() == 0 { return Ok(Box::pin(MemoryStream::try_new( diff --git a/datafusion/src/physical_plan/empty.rs b/datafusion/src/physical_plan/empty.rs index 045026b70ed58..0fbf18861d912 100644 --- a/datafusion/src/physical_plan/empty.rs +++ b/datafusion/src/physical_plan/empty.rs @@ -31,7 +31,7 @@ use arrow::record_batch::RecordBatch; use super::expressions::PhysicalSortExpr; use super::{common, SendableRecordBatchStream, Statistics}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use async_trait::async_trait; /// Execution plan for empty relation (produces no rows) @@ -121,7 +121,7 @@ impl ExecutionPlan for EmptyExec { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { // GlobalLimitExec has a single output partition if 0 != partition { @@ -161,18 +161,20 @@ impl ExecutionPlan for EmptyExec { #[cfg(test)] mod tests { use super::*; + use crate::prelude::SessionContext; use crate::{physical_plan::common, test_util}; #[tokio::test] async fn empty() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = test_util::aggr_test_schema(); let empty = EmptyExec::new(false, schema.clone()); assert_eq!(empty.schema(), schema); // we should have no results - let iter = empty.execute(0, runtime).await?; + let iter = empty.execute(0, task_ctx).await?; let batches = common::collect(iter).await?; assert!(batches.is_empty()); @@ -201,23 +203,25 @@ mod tests { #[tokio::test] async fn invalid_execute() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = test_util::aggr_test_schema(); let empty = EmptyExec::new(false, schema); // ask for the wrong partition - assert!(empty.execute(1, runtime.clone()).await.is_err()); - assert!(empty.execute(20, runtime.clone()).await.is_err()); + assert!(empty.execute(1, task_ctx.clone()).await.is_err()); + assert!(empty.execute(20, task_ctx).await.is_err()); Ok(()) } #[tokio::test] async fn 
produce_one_row() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = test_util::aggr_test_schema(); let empty = EmptyExec::new(true, schema); - let iter = empty.execute(0, runtime).await?; + let iter = empty.execute(0, task_ctx).await?; let batches = common::collect(iter).await?; // should have one item diff --git a/datafusion/src/physical_plan/explain.rs b/datafusion/src/physical_plan/explain.rs index 0955655a19293..d09eae3aef4f6 100644 --- a/datafusion/src/physical_plan/explain.rs +++ b/datafusion/src/physical_plan/explain.rs @@ -31,7 +31,7 @@ use crate::{ use arrow::{array::StringBuilder, datatypes::SchemaRef, record_batch::RecordBatch}; use super::{expressions::PhysicalSortExpr, SendableRecordBatchStream}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::physical_plan::metrics::{ExecutionPlanMetricsSet, MemTrackingMetrics}; use async_trait::async_trait; @@ -114,7 +114,7 @@ impl ExecutionPlan for ExplainExec { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { if 0 != partition { return Err(DataFusionError::Internal(format!( diff --git a/datafusion/src/physical_plan/file_format/avro.rs b/datafusion/src/physical_plan/file_format/avro.rs index ba0873d78b2b3..89cae4e143e5a 100644 --- a/datafusion/src/physical_plan/file_format/avro.rs +++ b/datafusion/src/physical_plan/file_format/avro.rs @@ -27,7 +27,7 @@ use arrow::datatypes::SchemaRef; #[cfg(feature = "avro")] use arrow::error::ArrowError; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use async_trait::async_trait; use std::any::Any; use std::sync::Arc; @@ -105,7 +105,7 @@ impl ExecutionPlan for AvroExec { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { Err(DataFusionError::NotImplemented( "Cannot execute avro plan without avro feature enabled".to_string(), @@ -116,11 +116,11 @@ impl ExecutionPlan for AvroExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { let proj = self.base_config.projected_file_column_names(); - let batch_size = runtime.batch_size(); + let batch_size = context.runtime.batch_size(); let file_schema = Arc::clone(&self.base_config.file_schema); // The avro reader cannot limit the number of records, so `remaining` is ignored. diff --git a/datafusion/src/physical_plan/file_format/csv.rs b/datafusion/src/physical_plan/file_format/csv.rs index d9f4706fdf0b0..0f38291e5ee3d 100644 --- a/datafusion/src/physical_plan/file_format/csv.rs +++ b/datafusion/src/physical_plan/file_format/csv.rs @@ -18,13 +18,12 @@ //! 
Execution plan for reading CSV files use crate::error::{DataFusionError, Result}; -use crate::execution::context::ExecutionContext; +use crate::execution::context::{SessionContext, TaskContext}; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }; -use crate::execution::runtime_env::RuntimeEnv; use arrow::csv; use arrow::datatypes::SchemaRef; use async_trait::async_trait; @@ -123,9 +122,9 @@ impl ExecutionPlan for CsvExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { - let batch_size = runtime.batch_size(); + let batch_size = context.runtime.batch_size(); let file_schema = Arc::clone(&self.base_config.file_schema); let file_projection = self.base_config.file_column_projection_indices(); let has_header = self.has_header; @@ -181,14 +180,13 @@ impl ExecutionPlan for CsvExec { } pub async fn plan_to_csv( - context: &ExecutionContext, + context: &SessionContext, plan: Arc, path: impl AsRef, ) -> Result<()> { let path = path.as_ref(); // create directory to contain the CSV files (one per partition) let fs_path = Path::new(path); - let runtime = context.runtime_env(); match fs::create_dir(fs_path) { Ok(()) => { let mut tasks = vec![]; @@ -198,7 +196,8 @@ pub async fn plan_to_csv( let path = fs_path.join(&filename); let file = fs::File::create(path)?; let mut writer = csv::Writer::new(file); - let stream = plan.execute(i, runtime.clone()).await?; + let task_ctx = context.task_ctx(); + let stream = plan.execute(i, task_ctx).await?; let handle: JoinHandle> = task::spawn(async move { stream .map(|batch| writer.write(&batch?)) @@ -236,7 +235,8 @@ mod tests { #[tokio::test] async fn csv_exec_with_projection() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema(); let testdata = crate::test_util::arrow_test_data(); let filename = "aggregate_test_100.csv"; @@ -258,7 +258,7 @@ mod tests { assert_eq!(3, csv.projected_schema.fields().len()); assert_eq!(3, csv.schema().fields().len()); - let mut stream = csv.execute(0, runtime).await?; + let mut stream = csv.execute(0, task_ctx).await?; let batch = stream.next().await.unwrap()?; assert_eq!(3, batch.num_columns()); assert_eq!(100, batch.num_rows()); @@ -282,7 +282,8 @@ mod tests { #[tokio::test] async fn csv_exec_with_limit() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema(); let testdata = crate::test_util::arrow_test_data(); let filename = "aggregate_test_100.csv"; @@ -304,7 +305,7 @@ mod tests { assert_eq!(13, csv.projected_schema.fields().len()); assert_eq!(13, csv.schema().fields().len()); - let mut it = csv.execute(0, runtime).await?; + let mut it = csv.execute(0, task_ctx).await?; let batch = it.next().await.unwrap()?; assert_eq!(13, batch.num_columns()); assert_eq!(5, batch.num_rows()); @@ -328,7 +329,8 @@ mod tests { #[tokio::test] async fn csv_exec_with_missing_column() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema_with_missing_col(); let testdata = crate::test_util::arrow_test_data(); let filename = "aggregate_test_100.csv"; @@ -350,7 +352,7 @@ mod tests { assert_eq!(14, 
csv.projected_schema.fields().len()); assert_eq!(14, csv.schema().fields().len()); - let mut it = csv.execute(0, runtime).await?; + let mut it = csv.execute(0, task_ctx).await?; let batch = it.next().await.unwrap()?; assert_eq!(14, batch.num_columns()); assert_eq!(5, batch.num_rows()); @@ -374,7 +376,8 @@ mod tests { #[tokio::test] async fn csv_exec_with_partition() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let file_schema = aggr_test_schema(); let testdata = crate::test_util::arrow_test_data(); let filename = "aggregate_test_100.csv"; @@ -403,7 +406,7 @@ mod tests { assert_eq!(2, csv.projected_schema.fields().len()); assert_eq!(2, csv.schema().fields().len()); - let mut it = csv.execute(0, runtime).await?; + let mut it = csv.execute(0, task_ctx).await?; let batch = it.next().await.unwrap()?; assert_eq!(2, batch.num_columns()); assert_eq!(100, batch.num_rows()); @@ -457,9 +460,8 @@ mod tests { async fn write_csv_results() -> Result<()> { // create partitioned input file and context let tmp_dir = TempDir::new()?; - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_target_partitions(8), - ); + let mut ctx = + SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); let schema = populate_csv_partitions(&tmp_dir, 8, ".csv")?; @@ -477,7 +479,7 @@ mod tests { df.write_csv(&out_dir).await?; // create a new context and verify that the results were saved to a partitioned csv file - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::UInt32, false), diff --git a/datafusion/src/physical_plan/file_format/json.rs b/datafusion/src/physical_plan/file_format/json.rs index 6c5ffcd99eac0..9c3a2e1ecaff4 100644 --- a/datafusion/src/physical_plan/file_format/json.rs +++ b/datafusion/src/physical_plan/file_format/json.rs @@ -19,7 +19,7 @@ use async_trait::async_trait; use crate::error::{DataFusionError, Result}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, @@ -95,11 +95,11 @@ impl ExecutionPlan for NdJsonExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { let proj = self.base_config.projected_file_column_names(); - let batch_size = runtime.batch_size(); + let batch_size = context.runtime.batch_size(); let file_schema = Arc::clone(&self.base_config.file_schema); // The json reader cannot limit the number of records, so `remaining` is ignored. 
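// End-to-end sketch of the migrated CSV write path shown above: the caller keeps
// using DataFrame::write_csv, while plan_to_csv now borrows the SessionContext
// and pulls a fresh TaskContext per partition internally. The table name and
// paths below are placeholders.
use datafusion::error::Result;
use datafusion::prelude::{CsvReadOptions, SessionConfig, SessionContext};

async fn write_csv_sketch() -> Result<()> {
    let mut ctx =
        SessionContext::with_config(SessionConfig::new().with_target_partitions(8));
    ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new())
        .await?;

    // One CSV file per output partition is written under the target directory.
    let df = ctx.sql("SELECT * FROM example").await?;
    df.write_csv("/tmp/example_csv_out").await?;
    Ok(())
}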
@@ -156,6 +156,7 @@ mod tests { local_object_reader_stream, local_unpartitioned_file, LocalFileSystem, }, }; + use crate::prelude::SessionContext; use super::*; @@ -169,7 +170,8 @@ mod tests { #[tokio::test] async fn nd_json_exec_file_without_projection() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); use arrow::datatypes::DataType; let path = format!("{}/1.json", TEST_DATA_BASE); let exec = NdJsonExec::new(FileScanConfig { @@ -206,7 +208,7 @@ mod tests { &DataType::Utf8 ); - let mut it = exec.execute(0, runtime).await?; + let mut it = exec.execute(0, task_ctx).await?; let batch = it.next().await.unwrap()?; assert_eq!(batch.num_rows(), 3); @@ -224,7 +226,8 @@ mod tests { #[tokio::test] async fn nd_json_exec_file_with_missing_column() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); use arrow::datatypes::DataType; let path = format!("{}/1.json", TEST_DATA_BASE); @@ -246,7 +249,7 @@ mod tests { table_partition_cols: vec![], }); - let mut it = exec.execute(0, runtime).await?; + let mut it = exec.execute(0, task_ctx).await?; let batch = it.next().await.unwrap()?; assert_eq!(batch.num_rows(), 3); @@ -265,7 +268,8 @@ mod tests { #[tokio::test] async fn nd_json_exec_file_projection() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let path = format!("{}/1.json", TEST_DATA_BASE); let exec = NdJsonExec::new(FileScanConfig { object_store: Arc::new(LocalFileSystem {}), @@ -284,7 +288,7 @@ mod tests { inferred_schema.field_with_name("c").unwrap(); inferred_schema.field_with_name("d").unwrap_err(); - let mut it = exec.execute(0, runtime).await?; + let mut it = exec.execute(0, task_ctx).await?; let batch = it.next().await.unwrap()?; assert_eq!(batch.num_rows(), 4); diff --git a/datafusion/src/physical_plan/file_format/parquet.rs b/datafusion/src/physical_plan/file_format/parquet.rs index 2d23ca1c3ada5..f532f6a88d277 100644 --- a/datafusion/src/physical_plan/file_format/parquet.rs +++ b/datafusion/src/physical_plan/file_format/parquet.rs @@ -27,7 +27,7 @@ use std::{any::Any, convert::TryInto}; use crate::datasource::file_format::parquet::ChunkObjectReader; use crate::datasource::object_store::ObjectStore; use crate::datasource::PartitionedFile; -use crate::execution::context::ExecutionContext; +use crate::execution::context::{SessionContext, TaskContext}; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::{ error::{DataFusionError, Result}, @@ -68,7 +68,6 @@ use tokio::{ task, }; -use crate::execution::runtime_env::RuntimeEnv; use crate::physical_plan::file_format::SchemaAdapter; use async_trait::async_trait; @@ -210,7 +209,7 @@ impl ExecutionPlan for ParquetExec { async fn execute( &self, partition_index: usize, - runtime: Arc, + context: Arc, ) -> Result { // because the parquet implementation is not thread-safe, it is necessary to execute // on a thread and communicate with channels @@ -226,7 +225,7 @@ impl ExecutionPlan for ParquetExec { None => (0..self.base_config.file_schema.fields().len()).collect(), }; let pruning_predicate = self.pruning_predicate.clone(); - let batch_size = runtime.batch_size(); + let batch_size = context.runtime.batch_size(); let limit = self.base_config.limit; let object_store = Arc::clone(&self.base_config.object_store); let partition_col_proj = 
PartitionColumnProjector::new( @@ -528,7 +527,7 @@ fn read_partition( /// Executes a query and writes the results to a partitioned Parquet file. pub async fn plan_to_parquet( - context: &ExecutionContext, + context: &SessionContext, plan: Arc, path: impl AsRef, writer_properties: Option, @@ -536,7 +535,6 @@ pub async fn plan_to_parquet( let path = path.as_ref(); // create directory to contain the Parquet files (one per partition) let fs_path = Path::new(path); - let runtime = context.runtime_env(); match fs::create_dir(fs_path) { Ok(()) => { let mut tasks = vec![]; @@ -550,7 +548,8 @@ pub async fn plan_to_parquet( plan.schema(), writer_properties.clone(), )?; - let stream = plan.execute(i, runtime.clone()).await?; + let task_ctx = context.task_ctx(); + let stream = plan.execute(i, task_ctx).await?; let handle: JoinHandle> = task::spawn(async move { stream .map(|batch| writer.write(&batch?)) @@ -589,7 +588,7 @@ mod tests { use super::*; use crate::execution::options::CsvReadOptions; - use crate::prelude::ExecutionConfig; + use crate::prelude::SessionConfig; use arrow::array::Float32Array; use arrow::{ array::{Int64Array, Int8Array, StringArray}, @@ -669,8 +668,9 @@ mod tests { None, ); - let runtime = Arc::new(RuntimeEnv::default()); - collect(Arc::new(parquet_exec), runtime).await + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + collect(Arc::new(parquet_exec), task_ctx).await } // Add a new column with the specified field name to the RecordBatch @@ -887,7 +887,8 @@ mod tests { #[tokio::test] async fn parquet_exec_with_projection() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/alltypes_plain.parquet", testdata); let parquet_exec = ParquetExec::new( @@ -906,7 +907,7 @@ mod tests { ); assert_eq!(parquet_exec.output_partitioning().partition_count(), 1); - let mut results = parquet_exec.execute(0, runtime).await?; + let mut results = parquet_exec.execute(0, task_ctx).await?; let batch = results.next().await.unwrap()?; assert_eq!(8, batch.num_rows()); @@ -931,7 +932,8 @@ mod tests { #[tokio::test] async fn parquet_exec_with_partition() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/alltypes_plain.parquet", testdata); let mut partitioned_file = local_unpartitioned_file(filename.clone()); @@ -961,7 +963,7 @@ mod tests { ); assert_eq!(parquet_exec.output_partitioning().partition_count(), 1); - let mut results = parquet_exec.execute(0, runtime).await?; + let mut results = parquet_exec.execute(0, task_ctx).await?; let batch = results.next().await.unwrap()?; let expected = vec![ "+----+----------+-------------+-------+", @@ -987,7 +989,8 @@ mod tests { #[tokio::test] async fn parquet_exec_with_error() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let testdata = crate::test_util::parquet_test_data(); let filename = format!("{}/alltypes_plain.parquet", testdata); let partitioned_file = PartitionedFile { @@ -1016,7 +1019,7 @@ mod tests { None, ); - let mut results = parquet_exec.execute(0, runtime).await?; + let mut results = parquet_exec.execute(0, task_ctx).await?; let batch = 
results.next().await.unwrap(); // invalid file should produce an error to that effect assert_contains!( @@ -1318,9 +1321,8 @@ mod tests { // create partitioned input file and context let tmp_dir = TempDir::new()?; // let mut ctx = create_ctx(&tmp_dir, 4).await?; - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_target_partitions(8), - ); + let mut ctx = + SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); let schema = populate_csv_partitions(&tmp_dir, 4, ".csv")?; // register csv file with the execution context ctx.register_csv( @@ -1337,7 +1339,7 @@ mod tests { // write_parquet(&mut ctx, "SELECT c1, c2 FROM test", &out_dir, None).await?; // create a new context and verify that the results were saved to a partitioned csv file - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); // register each partition as well as the top level dir ctx.register_parquet("part0", &format!("{}/part-0.parquet", out_dir)) diff --git a/datafusion/src/physical_plan/filter.rs b/datafusion/src/physical_plan/filter.rs index 69ff6bfc995be..689cbe244eec7 100644 --- a/datafusion/src/physical_plan/filter.rs +++ b/datafusion/src/physical_plan/filter.rs @@ -38,7 +38,7 @@ use arrow::record_batch::RecordBatch; use async_trait::async_trait; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use futures::stream::{Stream, StreamExt}; /// FilterExec evaluates a boolean predicate against all input batches to determine which rows to @@ -136,14 +136,14 @@ impl ExecutionPlan for FilterExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); Ok(Box::pin(FilterExecStream { schema: self.input.schema().clone(), predicate: self.predicate.clone(), - input: self.input.execute(partition, runtime).await?, + input: self.input.execute(partition, context).await?, baseline_metrics, })) } @@ -246,6 +246,7 @@ mod tests { use crate::physical_plan::expressions::*; use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; use crate::physical_plan::ExecutionPlan; + use crate::prelude::SessionContext; use crate::scalar::ScalarValue; use crate::test; use crate::test_util; @@ -254,7 +255,8 @@ mod tests { #[tokio::test] async fn simple_predicate() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = test_util::aggr_test_schema(); let partitions = 4; @@ -295,7 +297,7 @@ mod tests { let filter: Arc = Arc::new(FilterExec::try_new(predicate, Arc::new(csv))?); - let results = collect(filter, runtime).await?; + let results = collect(filter, task_ctx).await?; results .iter() diff --git a/datafusion/src/physical_plan/hash_aggregate.rs b/datafusion/src/physical_plan/hash_aggregate.rs index 33d3bccbba53b..7476f77edd8bf 100644 --- a/datafusion/src/physical_plan/hash_aggregate.rs +++ b/datafusion/src/physical_plan/hash_aggregate.rs @@ -48,7 +48,7 @@ use arrow::{ use hashbrown::raw::RawTable; use pin_project_lite::pin_project; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use async_trait::async_trait; use super::common::AbortOnDropSingle; @@ -233,9 +233,9 @@ impl ExecutionPlan for HashAggregateExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { - let input = self.input.execute(partition, runtime).await?; + let input = 
self.input.execute(partition, context).await?; let group_expr = self.group_expr.iter().map(|x| x.0.clone()).collect(); let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); @@ -1030,6 +1030,7 @@ mod tests { use futures::FutureExt; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; + use crate::prelude::SessionContext; /// some mock data to aggregates fn some_data() -> (Arc, Vec) { @@ -1076,7 +1077,8 @@ mod tests { DataType::Float64, ))]; - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let partial_aggregate = Arc::new(HashAggregateExec::try_new( AggregateMode::Partial, @@ -1087,7 +1089,8 @@ mod tests { )?); let result = - common::collect(partial_aggregate.execute(0, runtime.clone()).await?).await?; + common::collect(partial_aggregate.execute(0, task_ctx.clone()).await?) + .await?; let expected = vec![ "+---+---------------+-------------+", @@ -1119,7 +1122,7 @@ mod tests { )?); let result = - common::collect(merged_aggregate.execute(0, runtime.clone()).await?).await?; + common::collect(merged_aggregate.execute(0, task_ctx.clone()).await?).await?; assert_eq!(result.len(), 1); let batch = &result[0]; @@ -1187,7 +1190,7 @@ mod tests { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { let stream = if self.yield_first { TestYieldingStream::New @@ -1264,7 +1267,8 @@ mod tests { #[tokio::test] async fn test_drop_cancel_without_groups() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -1286,7 +1290,7 @@ mod tests { schema, )?); - let fut = crate::physical_plan::collect(hash_aggregate_exec, runtime); + let fut = crate::physical_plan::collect(hash_aggregate_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); @@ -1298,7 +1302,8 @@ mod tests { #[tokio::test] async fn test_drop_cancel_with_groups() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Float32, true), Field::new("b", DataType::Float32, true), @@ -1323,7 +1328,7 @@ mod tests { schema, )?); - let fut = crate::physical_plan::collect(hash_aggregate_exec, runtime); + let fut = crate::physical_plan::collect(hash_aggregate_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index d276ac2e72de2..ec7e032f4b8e0 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -71,7 +71,7 @@ use super::{ }; use crate::arrow::array::BooleanBufferBuilder; use crate::arrow::datatypes::TimeUnit; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::physical_plan::coalesce_batches::concat_batches; use crate::physical_plan::PhysicalExpr; use log::debug; @@ -290,7 +290,7 @@ impl ExecutionPlan for HashJoinExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { let on_left = self.on.iter().map(|on| on.0.clone()).collect::>(); // we only want to compute the build side once for PartitionMode::CollectLeft @@ -306,7 +306,7 @@ impl ExecutionPlan for HashJoinExec { // merge all left parts into a 
single stream let merge = CoalescePartitionsExec::new(self.left.clone()); - let stream = merge.execute(0, runtime.clone()).await?; + let stream = merge.execute(0, context.clone()).await?; // This operation performs 2 steps at once: // 1. creates a [JoinHashMap] of all batches from the stream @@ -359,7 +359,7 @@ impl ExecutionPlan for HashJoinExec { let start = Instant::now(); // Load 1 partition of left side in memory - let stream = self.left.execute(partition, runtime.clone()).await?; + let stream = self.left.execute(partition, context.clone()).await?; // This operation performs 2 steps at once: // 1. creates a [JoinHashMap] of all batches from the stream @@ -410,7 +410,7 @@ impl ExecutionPlan for HashJoinExec { // we have the batches and the hash map with their keys. We can how create a stream // over the right that uses this information to issue new batches. - let right_stream = self.right.execute(partition, runtime.clone()).await?; + let right_stream = self.right.execute(partition, context.clone()).await?; let on_right = self.on.iter().map(|on| on.1.clone()).collect::>(); let num_rows = left_data.1.num_rows(); @@ -1063,6 +1063,7 @@ mod tests { }; use super::*; + use crate::prelude::SessionContext; use std::sync::Arc; fn build_table( @@ -1098,12 +1099,12 @@ mod tests { on: JoinOn, join_type: &JoinType, null_equals_null: bool, - runtime: Arc, + context: Arc, ) -> Result<(Vec, Vec)> { let join = join(left, right, on, join_type, null_equals_null)?; let columns = columns(&join.schema()); - let stream = join.execute(0, runtime).await?; + let stream = join.execute(0, context).await?; let batches = common::collect(stream).await?; Ok((columns, batches)) @@ -1115,7 +1116,7 @@ mod tests { on: JoinOn, join_type: &JoinType, null_equals_null: bool, - runtime: Arc, + context: Arc, ) -> Result<(Vec, Vec)> { let partition_count = 4; @@ -1148,7 +1149,7 @@ mod tests { let mut batches = vec![]; for i in 0..partition_count { - let stream = join.execute(i, runtime.clone()).await?; + let stream = join.execute(i, context.clone()).await?; let more_batches = common::collect(stream).await?; batches.extend( more_batches @@ -1163,7 +1164,8 @@ mod tests { #[tokio::test] async fn join_inner_one() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1186,7 +1188,7 @@ mod tests { on.clone(), &JoinType::Inner, false, - runtime, + task_ctx, ) .await?; @@ -1208,7 +1210,8 @@ mod tests { #[tokio::test] async fn partitioned_join_inner_one() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1230,7 +1233,7 @@ mod tests { on.clone(), &JoinType::Inner, false, - runtime, + task_ctx, ) .await?; @@ -1252,7 +1255,8 @@ mod tests { #[tokio::test] async fn join_inner_one_no_shared_column_names() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1269,7 +1273,7 @@ mod tests { )]; let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, runtime).await?; + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; 
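// A sketch of the per-partition driving pattern used by partitioned_join_collect
// above: the Arc<TaskContext> is cloned once per partition so every execute()
// call shares the same session-derived state. `plan` stands in for any physical
// plan built elsewhere.
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::physical_plan::{common, ExecutionPlan};
use datafusion::prelude::SessionContext;

async fn collect_all_partitions_sketch(plan: Arc<dyn ExecutionPlan>) -> Result<()> {
    let ctx = SessionContext::new();
    let task_ctx = ctx.task_ctx();

    for partition in 0..plan.output_partitioning().partition_count() {
        // One cheap clone of the shared task context per partition.
        let stream = plan.execute(partition, task_ctx.clone()).await?;
        let _batches = common::collect(stream).await?;
    }
    Ok(())
}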
assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); @@ -1290,7 +1294,8 @@ mod tests { #[tokio::test] async fn join_inner_two() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let left = build_table( ("a1", &vec![1, 2, 2]), ("b2", &vec![1, 2, 2]), @@ -1313,7 +1318,7 @@ mod tests { ]; let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, runtime).await?; + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); @@ -1337,7 +1342,8 @@ mod tests { /// Test where the left has 2 parts, the right with 1 part => 1 part #[tokio::test] async fn join_inner_one_two_parts_left() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let batch1 = build_table_i32( ("a1", &vec![1, 2]), ("b2", &vec![1, 2]), @@ -1367,7 +1373,7 @@ mod tests { ]; let (columns, batches) = - join_collect(left, right, on, &JoinType::Inner, false, runtime).await?; + join_collect(left, right, on, &JoinType::Inner, false, task_ctx).await?; assert_eq!(columns, vec!["a1", "b2", "c1", "a1", "b2", "c2"]); @@ -1391,7 +1397,8 @@ mod tests { /// Test where the left has 1 part, the right has 2 parts => 2 parts #[tokio::test] async fn join_inner_one_two_parts_right() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 5]), // this has a repetition @@ -1421,7 +1428,7 @@ mod tests { assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); // first part - let stream = join.execute(0, runtime.clone()).await?; + let stream = join.execute(0, task_ctx.clone()).await?; let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); @@ -1435,7 +1442,7 @@ mod tests { assert_batches_sorted_eq!(expected, &batches); // second part - let stream = join.execute(1, runtime.clone()).await?; + let stream = join.execute(1, task_ctx.clone()).await?; let batches = common::collect(stream).await?; assert_eq!(batches.len(), 1); let expected = vec![ @@ -1466,7 +1473,8 @@ mod tests { #[tokio::test] async fn join_left_multi_batch() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1487,7 +1495,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); - let stream = join.execute(0, runtime).await.unwrap(); + let stream = join.execute(0, task_ctx).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ @@ -1507,7 +1515,8 @@ mod tests { #[tokio::test] async fn join_full_multi_batch() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1529,7 +1538,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let stream = join.execute(0, runtime).await.unwrap(); + let stream = join.execute(0, task_ctx).await.unwrap(); let batches = 
common::collect(stream).await.unwrap(); let expected = vec![ @@ -1551,7 +1560,8 @@ mod tests { #[tokio::test] async fn join_left_empty_right() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -1569,7 +1579,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); - let stream = join.execute(0, runtime).await.unwrap(); + let stream = join.execute(0, task_ctx).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ @@ -1587,7 +1597,8 @@ mod tests { #[tokio::test] async fn join_full_empty_right() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -1605,7 +1616,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let stream = join.execute(0, runtime).await.unwrap(); + let stream = join.execute(0, task_ctx).await.unwrap(); let batches = common::collect(stream).await.unwrap(); let expected = vec![ @@ -1623,7 +1634,8 @@ mod tests { #[tokio::test] async fn join_left_one() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1645,7 +1657,7 @@ mod tests { on.clone(), &JoinType::Left, false, - runtime, + task_ctx, ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -1666,7 +1678,8 @@ mod tests { #[tokio::test] async fn partitioned_join_left_one() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1688,7 +1701,7 @@ mod tests { on.clone(), &JoinType::Left, false, - runtime, + task_ctx, ) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -1709,7 +1722,8 @@ mod tests { #[tokio::test] async fn join_semi() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let left = build_table( ("a1", &vec![1, 2, 2, 3]), ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right @@ -1730,7 +1744,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1"]); - let stream = join.execute(0, runtime).await?; + let stream = join.execute(0, task_ctx).await?; let batches = common::collect(stream).await?; let expected = vec![ @@ -1749,7 +1763,8 @@ mod tests { #[tokio::test] async fn join_anti() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let left = build_table( ("a1", &vec![1, 2, 2, 3, 5]), ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right @@ -1770,7 +1785,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1"]); - let stream = join.execute(0, runtime).await?; + let stream = join.execute(0, task_ctx).await?; let batches = common::collect(stream).await?; let expected = vec![ @@ -1787,7 +1802,8 @@ mod tests { 
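// Template sketch for downstream optimizer rules affected by the signature change
// in physical_optimizer earlier in this patch: optimize() now receives a
// &SessionConfig. The name() signature below is assumed to remain `&str`, and the
// rule itself is an illustrative no-op.
use std::sync::Arc;

use datafusion::error::Result;
use datafusion::execution::context::SessionConfig;
use datafusion::physical_optimizer::optimizer::PhysicalOptimizerRule;
use datafusion::physical_plan::ExecutionPlan;

struct PassThroughRule;

impl PhysicalOptimizerRule for PassThroughRule {
    fn optimize(
        &self,
        plan: Arc<dyn ExecutionPlan>,
        _config: &SessionConfig,
    ) -> Result<Arc<dyn ExecutionPlan>> {
        // No rewrite: hand the plan back unchanged.
        Ok(plan)
    }

    fn name(&self) -> &str {
        "pass_through"
    }
}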
#[tokio::test] async fn join_right_one() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -1804,7 +1820,7 @@ mod tests { )]; let (columns, batches) = - join_collect(left, right, on, &JoinType::Right, false, runtime).await?; + join_collect(left, right, on, &JoinType::Right, false, task_ctx).await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -1825,7 +1841,8 @@ mod tests { #[tokio::test] async fn partitioned_join_right_one() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), @@ -1842,7 +1859,7 @@ mod tests { )]; let (columns, batches) = - partitioned_join_collect(left, right, on, &JoinType::Right, false, runtime) + partitioned_join_collect(left, right, on, &JoinType::Right, false, task_ctx) .await?; assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b1", "c2"]); @@ -1864,7 +1881,8 @@ mod tests { #[tokio::test] async fn join_full_one() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let left = build_table( ("a1", &vec![1, 2, 3]), ("b1", &vec![4, 5, 7]), // 7 does not exist on the right @@ -1885,7 +1903,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a1", "b1", "c1", "a2", "b2", "c2"]); - let stream = join.execute(0, runtime).await?; + let stream = join.execute(0, task_ctx).await?; let batches = common::collect(stream).await?; let expected = vec![ @@ -1955,7 +1973,8 @@ mod tests { #[tokio::test] async fn join_with_duplicated_column_names() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let left = build_table( ("a", &vec![1, 2, 3]), ("b", &vec![4, 5, 7]), @@ -1977,7 +1996,7 @@ mod tests { let columns = columns(&join.schema()); assert_eq!(columns, vec!["a", "b", "c", "a", "b", "c"]); - let stream = join.execute(0, runtime).await?; + let stream = join.execute(0, task_ctx).await?; let batches = common::collect(stream).await?; let expected = vec![ diff --git a/datafusion/src/physical_plan/limit.rs b/datafusion/src/physical_plan/limit.rs index f150c5601294c..f5de703eab105 100644 --- a/datafusion/src/physical_plan/limit.rs +++ b/datafusion/src/physical_plan/limit.rs @@ -41,7 +41,7 @@ use super::{ RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use async_trait::async_trait; /// Limit execution plan @@ -134,7 +134,7 @@ impl ExecutionPlan for GlobalLimitExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { // GlobalLimitExec has a single output partition if 0 != partition { @@ -152,7 +152,7 @@ impl ExecutionPlan for GlobalLimitExec { } let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); - let stream = self.input.execute(0, runtime).await?; + let stream = self.input.execute(0, context).await?; Ok(Box::pin(LimitStream::new( stream, self.limit, @@ -285,10 +285,10 @@ impl ExecutionPlan for LocalLimitExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { let baseline_metrics = BaselineMetrics::new(&self.metrics, 
partition); - let stream = self.input.execute(partition, runtime).await?; + let stream = self.input.execute(partition, context).await?; Ok(Box::pin(LimitStream::new( stream, self.limit, @@ -432,11 +432,13 @@ mod tests { use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::common; use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; + use crate::prelude::SessionContext; use crate::{test, test_util}; #[tokio::test] async fn limit() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = test_util::aggr_test_schema(); let num_partitions = 4; @@ -464,7 +466,7 @@ mod tests { GlobalLimitExec::new(Arc::new(CoalescePartitionsExec::new(Arc::new(csv))), 7); // the result should contain 4 batches (one per input partition) - let iter = limit.execute(0, runtime).await?; + let iter = limit.execute(0, task_ctx).await?; let batches = common::collect(iter).await?; // there should be a total of 100 rows diff --git a/datafusion/src/physical_plan/memory.rs b/datafusion/src/physical_plan/memory.rs index cc8208346516d..2662c551b7aa2 100644 --- a/datafusion/src/physical_plan/memory.rs +++ b/datafusion/src/physical_plan/memory.rs @@ -32,7 +32,7 @@ use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; use arrow::record_batch::RecordBatch; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use async_trait::async_trait; use futures::Stream; @@ -99,7 +99,7 @@ impl ExecutionPlan for MemoryExec { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { Ok(Box::pin(MemoryStream::try_new( self.partitions[partition].clone(), @@ -223,6 +223,7 @@ mod tests { use super::*; use crate::from_slice::FromSlice; use crate::physical_plan::ColumnStatistics; + use crate::prelude::SessionContext; use arrow::array::Int32Array; use arrow::datatypes::{DataType, Field, Schema}; use futures::StreamExt; @@ -250,7 +251,8 @@ mod tests { #[tokio::test] async fn test_with_projection() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let (schema, batch) = mock_data()?; let executor = MemoryExec::try_new(&[vec![batch]], schema, Some(vec![2, 1]))?; @@ -276,7 +278,7 @@ mod tests { ); // scan with projection - let mut it = executor.execute(0, runtime).await?; + let mut it = executor.execute(0, task_ctx).await?; let batch2 = it.next().await.unwrap()?; assert_eq!(2, batch2.schema().fields().len()); assert_eq!("c", batch2.schema().field(0).name()); @@ -288,7 +290,8 @@ mod tests { #[tokio::test] async fn test_without_projection() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let (schema, batch) = mock_data()?; let executor = MemoryExec::try_new(&[vec![batch]], schema, None)?; @@ -325,7 +328,7 @@ mod tests { ]) ); - let mut it = executor.execute(0, runtime).await?; + let mut it = executor.execute(0, task_ctx).await?; let batch1 = it.next().await.unwrap()?; assert_eq!(4, batch1.schema().fields().len()); assert_eq!(4, batch1.num_columns()); diff --git a/datafusion/src/physical_plan/mod.rs b/datafusion/src/physical_plan/mod.rs index e2ce99f2bdf4b..5e86db02b0b34 100644 --- a/datafusion/src/physical_plan/mod.rs +++ b/datafusion/src/physical_plan/mod.rs @@ -23,7 +23,7 @@ use self::{ 
coalesce_partitions::CoalescePartitionsExec, display::DisplayableExecutionPlan, }; use crate::physical_plan::expressions::PhysicalSortExpr; -use crate::{error::Result, execution::runtime_env::RuntimeEnv, scalar::ScalarValue}; +use crate::{error::Result, scalar::ScalarValue}; use arrow::datatypes::SchemaRef; use arrow::error::Result as ArrowResult; @@ -226,7 +226,7 @@ pub trait ExecutionPlan: Debug + Send + Sync { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result; /// Return a snapshot of the set of [`Metric`]s for this @@ -269,9 +269,9 @@ pub trait ExecutionPlan: Debug + Send + Sync { /// #[tokio::main] /// async fn main() { /// // Hard code target_partitions as it appears in the RepartitionExec output -/// let config = ExecutionConfig::new() +/// let config = SessionConfig::new() /// .with_target_partitions(3); -/// let mut ctx = ExecutionContext::with_config(config); +/// let mut ctx = SessionContext::with_config(config); /// /// // register the a table /// ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new()).await.unwrap(); @@ -385,26 +385,26 @@ pub fn visit_execution_plan( /// Execute the [ExecutionPlan] and collect the results in memory pub async fn collect( plan: Arc, - runtime: Arc, + context: Arc, ) -> Result> { - let stream = execute_stream(plan, runtime).await?; + let stream = execute_stream(plan, context).await?; common::collect(stream).await } /// Execute the [ExecutionPlan] and return a single stream of results pub async fn execute_stream( plan: Arc, - runtime: Arc, + context: Arc, ) -> Result { match plan.output_partitioning().partition_count() { 0 => Ok(Box::pin(EmptyRecordBatchStream::new(plan.schema()))), - 1 => plan.execute(0, runtime).await, + 1 => plan.execute(0, context).await, _ => { // merge into a single partition let plan = CoalescePartitionsExec::new(plan.clone()); // CoalescePartitionsExec must produce a single partition assert_eq!(1, plan.output_partitioning().partition_count()); - plan.execute(0, runtime).await + plan.execute(0, context).await } } } @@ -412,9 +412,9 @@ pub async fn execute_stream( /// Execute the [ExecutionPlan] and collect the results in memory pub async fn collect_partitioned( plan: Arc, - runtime: Arc, + context: Arc, ) -> Result>> { - let streams = execute_stream_partitioned(plan, runtime).await?; + let streams = execute_stream_partitioned(plan, context).await?; let mut batches = Vec::with_capacity(streams.len()); for stream in streams { batches.push(common::collect(stream).await?); @@ -425,12 +425,12 @@ pub async fn collect_partitioned( /// Execute the [ExecutionPlan] and return a vec with one stream per output partition pub async fn execute_stream_partitioned( plan: Arc, - runtime: Arc, + context: Arc, ) -> Result> { let num_partitions = plan.output_partitioning().partition_count(); let mut streams = Vec::with_capacity(num_partitions); for i in 0..num_partitions { - streams.push(plan.execute(i, runtime.clone()).await?); + streams.push(plan.execute(i, context.clone()).await?); } Ok(streams) } @@ -521,7 +521,9 @@ pub mod cross_join; pub mod display; pub mod empty; pub mod explain; +use crate::execution::context::TaskContext; pub use datafusion_physical_expr::expressions; + pub mod aggregate_rule; pub mod file_format; pub mod filter; diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index b3bcf37da6e02..37b303469b8ea 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -22,7 +22,7 @@ use 
super::{ aggregates, empty::EmptyExec, expressions::binary, functions, hash_join::PartitionMode, udaf, union::UnionExec, values::ValuesExec, windows, }; -use crate::execution::context::{ExecutionContextState, ExecutionProps}; +use crate::execution::context::{ExecutionProps, SessionState}; use crate::logical_plan::plan::{ Aggregate, EmptyRelation, Filter, Join, Projection, Sort, TableScan, Window, }; @@ -217,7 +217,7 @@ pub trait PhysicalPlanner { async fn create_physical_plan( &self, logical_plan: &LogicalPlan, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result>; /// Create a physical expression from a logical expression @@ -233,7 +233,7 @@ pub trait PhysicalPlanner { expr: &Expr, input_dfschema: &DFSchema, input_schema: &Schema, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result>; } @@ -255,7 +255,7 @@ pub trait ExtensionPlanner { node: &dyn UserDefinedLogicalNode, logical_inputs: &[&LogicalPlan], physical_inputs: &[Arc], - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result>>; } @@ -272,13 +272,15 @@ impl PhysicalPlanner for DefaultPhysicalPlanner { async fn create_physical_plan( &self, logical_plan: &LogicalPlan, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result> { - match self.handle_explain(logical_plan, ctx_state).await? { + match self.handle_explain(logical_plan, session_state).await? { Some(plan) => Ok(plan), None => { - let plan = self.create_initial_plan(logical_plan, ctx_state).await?; - self.optimize_internal(plan, ctx_state, |_, _| {}) + let plan = self + .create_initial_plan(logical_plan, session_state) + .await?; + self.optimize_internal(plan, session_state, |_, _| {}) } } } @@ -296,13 +298,13 @@ impl PhysicalPlanner for DefaultPhysicalPlanner { expr: &Expr, input_dfschema: &DFSchema, input_schema: &Schema, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result> { create_physical_expr( expr, input_dfschema, input_schema, - &ctx_state.execution_props, + &session_state.execution_props, ) } } @@ -322,7 +324,7 @@ impl DefaultPhysicalPlanner { fn create_initial_plan<'a>( &'a self, logical_plan: &'a LogicalPlan, - ctx_state: &'a ExecutionContextState, + session_state: &'a SessionState, ) -> BoxFuture<'a, Result>> { async move { let exec_plan: Result> = match logical_plan { @@ -352,7 +354,7 @@ impl DefaultPhysicalPlanner { expr, schema, &exec_schema, - ctx_state, + session_state, ) }) .collect::>>>() @@ -373,15 +375,15 @@ impl DefaultPhysicalPlanner { )); } - let input_exec = self.create_initial_plan(input, ctx_state).await?; + let input_exec = self.create_initial_plan(input, session_state).await?; // at this moment we are guaranteed by the logical planner // to have all the window_expr to have equal sort key let partition_keys = window_expr_common_partition_keys(window_expr)?; let can_repartition = !partition_keys.is_empty() - && ctx_state.config.target_partitions > 1 - && ctx_state.config.repartition_windows; + && session_state.config.target_partitions > 1 + && session_state.config.repartition_windows; let input_exec = if can_repartition { let partition_keys = partition_keys @@ -391,7 +393,7 @@ impl DefaultPhysicalPlanner { e, input.schema(), &input_exec.schema(), - ctx_state, + session_state, ) }) .collect::>>>()?; @@ -399,7 +401,7 @@ impl DefaultPhysicalPlanner { input_exec, Partitioning::Hash( partition_keys, - ctx_state.config.target_partitions, + session_state.config.target_partitions, ), )?) 
} else { @@ -446,7 +448,7 @@ impl DefaultPhysicalPlanner { descending: !*asc, nulls_first: *nulls_first, }, - &ctx_state.execution_props, + &session_state.execution_props, ), _ => unreachable!(), }) @@ -466,7 +468,7 @@ impl DefaultPhysicalPlanner { e, logical_input_schema, &physical_input_schema, - &ctx_state.execution_props, + &session_state.execution_props, ) }) .collect::>>()?; @@ -484,7 +486,7 @@ impl DefaultPhysicalPlanner { .. }) => { // Initially need to perform the aggregate and then merge the partitions - let input_exec = self.create_initial_plan(input, ctx_state).await?; + let input_exec = self.create_initial_plan(input, session_state).await?; let physical_input_schema = input_exec.schema(); let logical_input_schema = input.as_ref().schema(); @@ -496,7 +498,7 @@ impl DefaultPhysicalPlanner { e, logical_input_schema, &physical_input_schema, - ctx_state, + session_state, ), physical_name(e), )) @@ -509,7 +511,7 @@ impl DefaultPhysicalPlanner { e, logical_input_schema, &physical_input_schema, - &ctx_state.execution_props, + &session_state.execution_props, ) }) .collect::>>()?; @@ -532,8 +534,8 @@ impl DefaultPhysicalPlanner { .any(|x| matches!(x, DataType::Dictionary(_, _))); let can_repartition = !groups.is_empty() - && ctx_state.config.target_partitions > 1 - && ctx_state.config.repartition_aggregations + && session_state.config.target_partitions > 1 + && session_state.config.repartition_aggregations && !contains_dict; let (initial_aggr, next_partition_mode): ( @@ -545,7 +547,7 @@ impl DefaultPhysicalPlanner { initial_aggr, Partitioning::Hash( final_group.clone(), - ctx_state.config.target_partitions, + session_state.config.target_partitions, ), )?); // Combine hash aggregates within the partition @@ -569,7 +571,7 @@ impl DefaultPhysicalPlanner { )?) ) } LogicalPlan::Projection(Projection { input, expr, .. }) => { - let input_exec = self.create_initial_plan(input, ctx_state).await?; + let input_exec = self.create_initial_plan(input, session_state).await?; let input_schema = input.as_ref().schema(); let physical_exprs = expr @@ -608,7 +610,7 @@ impl DefaultPhysicalPlanner { e, input_schema, &input_exec.schema(), - ctx_state, + session_state, ), physical_name, )) @@ -623,7 +625,7 @@ impl DefaultPhysicalPlanner { LogicalPlan::Filter(Filter { input, predicate, .. }) => { - let physical_input = self.create_initial_plan(input, ctx_state).await?; + let physical_input = self.create_initial_plan(input, session_state).await?; let input_schema = physical_input.as_ref().schema(); let input_dfschema = input.as_ref().schema(); @@ -631,13 +633,13 @@ impl DefaultPhysicalPlanner { predicate, input_dfschema, &input_schema, - ctx_state, + session_state, )?; Ok(Arc::new(FilterExec::try_new(runtime_expr, physical_input)?) ) } LogicalPlan::Union(Union { inputs, .. 
}) => { let physical_plans = futures::stream::iter(inputs) - .then(|lp| self.create_initial_plan(lp, ctx_state)) + .then(|lp| self.create_initial_plan(lp, session_state)) .try_collect::>() .await?; Ok(Arc::new(UnionExec::new(physical_plans)) ) @@ -646,7 +648,7 @@ impl DefaultPhysicalPlanner { input, partitioning_scheme, }) => { - let physical_input = self.create_initial_plan(input, ctx_state).await?; + let physical_input = self.create_initial_plan(input, session_state).await?; let input_schema = physical_input.schema(); let input_dfschema = input.as_ref().schema(); let physical_partitioning = match partitioning_scheme { @@ -661,7 +663,7 @@ impl DefaultPhysicalPlanner { e, input_dfschema, &input_schema, - ctx_state, + session_state, ) }) .collect::>>()?; @@ -674,7 +676,7 @@ impl DefaultPhysicalPlanner { )?) ) } LogicalPlan::Sort(Sort { expr, input, .. }) => { - let physical_input = self.create_initial_plan(input, ctx_state).await?; + let physical_input = self.create_initial_plan(input, session_state).await?; let input_schema = physical_input.as_ref().schema(); let input_dfschema = input.as_ref().schema(); let sort_expr = expr @@ -692,7 +694,7 @@ impl DefaultPhysicalPlanner { descending: !*asc, nulls_first: *nulls_first, }, - &ctx_state.execution_props, + &session_state.execution_props, ), _ => Err(DataFusionError::Plan( "Sort only accepts sort expressions".to_string(), @@ -710,9 +712,9 @@ impl DefaultPhysicalPlanner { .. }) => { let left_df_schema = left.schema(); - let physical_left = self.create_initial_plan(left, ctx_state).await?; + let physical_left = self.create_initial_plan(left, session_state).await?; let right_df_schema = right.schema(); - let physical_right = self.create_initial_plan(right, ctx_state).await?; + let physical_right = self.create_initial_plan(right, session_state).await?; let join_on = keys .iter() .map(|(l, r)| { @@ -723,8 +725,8 @@ impl DefaultPhysicalPlanner { }) .collect::>()?; - if ctx_state.config.target_partitions > 1 - && ctx_state.config.repartition_joins + if session_state.config.target_partitions > 1 + && session_state.config.repartition_joins { let (left_expr, right_expr) = join_on .iter() @@ -742,14 +744,14 @@ impl DefaultPhysicalPlanner { physical_left, Partitioning::Hash( left_expr, - ctx_state.config.target_partitions, + session_state.config.target_partitions, ), )?), Arc::new(RepartitionExec::try_new( physical_right, Partitioning::Hash( right_expr, - ctx_state.config.target_partitions, + session_state.config.target_partitions, ), )?), join_on, @@ -769,8 +771,8 @@ impl DefaultPhysicalPlanner { } } LogicalPlan::CrossJoin(CrossJoin { left, right, .. }) => { - let left = self.create_initial_plan(left, ctx_state).await?; - let right = self.create_initial_plan(right, ctx_state).await?; + let left = self.create_initial_plan(left, session_state).await?; + let right = self.create_initial_plan(right, session_state).await?; Ok(Arc::new(CrossJoinExec::try_new(left, right)?)) } LogicalPlan::EmptyRelation(EmptyRelation { @@ -782,7 +784,7 @@ impl DefaultPhysicalPlanner { ))), LogicalPlan::Limit(Limit { input, n, .. 
}) => { let limit = *n; - let input = self.create_initial_plan(input, ctx_state).await?; + let input = self.create_initial_plan(input, session_state).await?; // GlobalLimitExec requires a single partition for input let input = if input.output_partitioning().partition_count() == 1 { @@ -815,13 +817,13 @@ impl DefaultPhysicalPlanner { "Unsupported logical plan: Explain must be root of the plan".to_string(), )), LogicalPlan::Analyze(a) => { - let input = self.create_initial_plan(&a.input, ctx_state).await?; + let input = self.create_initial_plan(&a.input, session_state).await?; let schema = SchemaRef::new((*a.schema).clone().into()); Ok(Arc::new(AnalyzeExec::new(a.verbose, input, schema))) } LogicalPlan::Extension(e) => { let physical_inputs = futures::stream::iter(e.node.inputs()) - .then(|lp| self.create_initial_plan(lp, ctx_state)) + .then(|lp| self.create_initial_plan(lp, session_state)) .try_collect::>() .await?; @@ -836,7 +838,7 @@ impl DefaultPhysicalPlanner { e.node.as_ref(), &e.node.inputs(), &physical_inputs, - ctx_state, + session_state, ) } }, @@ -1351,7 +1353,7 @@ impl DefaultPhysicalPlanner { async fn handle_explain( &self, logical_plan: &LogicalPlan, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result>> { if let LogicalPlan::Explain(e) = logical_plan { use PlanType::*; @@ -1359,16 +1361,19 @@ impl DefaultPhysicalPlanner { stringified_plans.push(e.plan.to_stringified(FinalLogicalPlan)); - let input = self.create_initial_plan(e.plan.as_ref(), ctx_state).await?; + let input = self + .create_initial_plan(e.plan.as_ref(), session_state) + .await?; stringified_plans .push(displayable(input.as_ref()).to_stringified(InitialPhysicalPlan)); - let input = self.optimize_internal(input, ctx_state, |plan, optimizer| { - let optimizer_name = optimizer.name().to_string(); - let plan_type = OptimizedPhysicalPlan { optimizer_name }; - stringified_plans.push(displayable(plan).to_stringified(plan_type)); - })?; + let input = + self.optimize_internal(input, session_state, |plan, optimizer| { + let optimizer_name = optimizer.name().to_string(); + let plan_type = OptimizedPhysicalPlan { optimizer_name }; + stringified_plans.push(displayable(plan).to_stringified(plan_type)); + })?; stringified_plans .push(displayable(input.as_ref()).to_stringified(FinalPhysicalPlan)); @@ -1388,13 +1393,13 @@ impl DefaultPhysicalPlanner { fn optimize_internal( &self, plan: Arc, - ctx_state: &ExecutionContextState, + session_state: &SessionState, mut observer: F, ) -> Result> where F: FnMut(&dyn ExecutionPlan, &dyn PhysicalOptimizerRule), { - let optimizers = &ctx_state.config.physical_optimizers; + let optimizers = &session_state.config.physical_optimizers; debug!( "Input physical plan:\n{}\n", displayable(plan.as_ref()).indent() @@ -1403,7 +1408,7 @@ impl DefaultPhysicalPlanner { let mut new_plan = plan; for optimizer in optimizers { - new_plan = optimizer.optimize(new_plan, &ctx_state.config)?; + new_plan = optimizer.optimize(new_plan, &session_state.config)?; observer(new_plan.as_ref(), optimizer.as_ref()) } debug!( @@ -1428,8 +1433,8 @@ fn tuple_err(value: (Result, Result)) -> Result<(T, R)> { mod tests { use super::*; use crate::datasource::object_store::local::LocalFileSystem; + use crate::execution::context::TaskContext; use crate::execution::options::CsvReadOptions; - use crate::execution::runtime_env::RuntimeEnv; use crate::logical_plan::plan::Extension; use crate::physical_plan::{ expressions, DisplayFormatType, Partitioning, Statistics, @@ -1448,15 +1453,17 @@ mod tests { use 
std::convert::TryFrom; use std::{any::Any, fmt}; - fn make_ctx_state() -> ExecutionContextState { - ExecutionContextState::new() + fn make_session_state() -> SessionState { + SessionState::new() } async fn plan(logical_plan: &LogicalPlan) -> Result> { - let mut ctx_state = make_ctx_state(); - ctx_state.config.target_partitions = 4; + let mut session_state = make_session_state(); + session_state.config.target_partitions = 4; let planner = DefaultPhysicalPlanner::default(); - planner.create_physical_plan(logical_plan, &ctx_state).await + planner + .create_physical_plan(logical_plan, &session_state) + .await } #[tokio::test] @@ -1502,7 +1509,7 @@ mod tests { &col("a").not(), &dfschema, &schema, - &make_ctx_state(), + &make_session_state(), )?; let expected = expressions::not(expressions::col("a", &schema)?, &schema)?; @@ -1581,13 +1588,13 @@ mod tests { #[tokio::test] async fn default_extension_planner() { - let ctx_state = make_ctx_state(); + let session_state = make_session_state(); let planner = DefaultPhysicalPlanner::default(); let logical_plan = LogicalPlan::Extension(Extension { node: Arc::new(NoOpExtensionNode::default()), }); let plan = planner - .create_physical_plan(&logical_plan, &ctx_state) + .create_physical_plan(&logical_plan, &session_state) .await; let expected_error = @@ -1607,7 +1614,7 @@ mod tests { async fn bad_extension_planner() { // Test that creating an execution plan whose schema doesn't // match the logical plan's schema generates an error. - let ctx_state = make_ctx_state(); + let session_state = make_session_state(); let planner = DefaultPhysicalPlanner::with_extension_planners(vec![Arc::new( BadExtensionPlanner {}, )]); @@ -1616,7 +1623,7 @@ mod tests { node: Arc::new(NoOpExtensionNode::default()), }); let plan = planner - .create_physical_plan(&logical_plan, &ctx_state) + .create_physical_plan(&logical_plan, &session_state) .await; let expected_error: &str = "Error during planning: \ @@ -1905,7 +1912,7 @@ mod tests { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { unimplemented!("NoOpExecutionPlan::execute"); } @@ -1939,7 +1946,7 @@ mod tests { _node: &dyn UserDefinedLogicalNode, _logical_inputs: &[&LogicalPlan], _physical_inputs: &[Arc], - _ctx_state: &ExecutionContextState, + _session_state: &SessionState, ) -> Result>> { Ok(Some(Arc::new(NoOpExecutionPlan { schema: SchemaRef::new(Schema::new(vec![Field::new( diff --git a/datafusion/src/physical_plan/projection.rs b/datafusion/src/physical_plan/projection.rs index 5940b64957c14..a9bb8481a9d1b 100644 --- a/datafusion/src/physical_plan/projection.rs +++ b/datafusion/src/physical_plan/projection.rs @@ -37,7 +37,7 @@ use arrow::record_batch::RecordBatch; use super::expressions::{Column, PhysicalSortExpr}; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{RecordBatchStream, SendableRecordBatchStream, Statistics}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use async_trait::async_trait; use futures::stream::Stream; use futures::stream::StreamExt; @@ -153,12 +153,12 @@ impl ExecutionPlan for ProjectionExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { Ok(Box::pin(ProjectionStream { schema: self.schema.clone(), expr: self.expr.iter().map(|x| x.0.clone()).collect(), - input: self.input.execute(partition, runtime).await?, + input: self.input.execute(partition, context).await?, baseline_metrics: BaselineMetrics::new(&self.metrics, partition), })) } 
@@ -303,6 +303,7 @@ mod tests { use crate::datasource::object_store::local::LocalFileSystem; use crate::physical_plan::expressions::{self, col}; use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; + use crate::prelude::SessionContext; use crate::scalar::ScalarValue; use crate::test::{self}; use crate::test_util; @@ -310,7 +311,8 @@ mod tests { #[tokio::test] async fn project_first_column() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = test_util::aggr_test_schema(); let partitions = 4; @@ -346,7 +348,7 @@ mod tests { let mut row_count = 0; for partition in 0..projection.output_partitioning().partition_count() { partition_count += 1; - let stream = projection.execute(partition, runtime.clone()).await?; + let stream = projection.execute(partition, task_ctx.clone()).await?; row_count += stream .map(|batch| { diff --git a/datafusion/src/physical_plan/repartition.rs b/datafusion/src/physical_plan/repartition.rs index 55328c40c951d..18cfb4cd3120f 100644 --- a/datafusion/src/physical_plan/repartition.rs +++ b/datafusion/src/physical_plan/repartition.rs @@ -37,7 +37,7 @@ use super::metrics::{self, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet}; use super::{RecordBatchStream, SendableRecordBatchStream}; use async_trait::async_trait; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use futures::stream::Stream; use futures::StreamExt; use hashbrown::HashMap; @@ -177,7 +177,7 @@ impl ExecutionPlan for RepartitionExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { // lock mutexes let mut state = self.state.lock().await; @@ -221,7 +221,7 @@ impl ExecutionPlan for RepartitionExec { txs.clone(), self.partitioning.clone(), r_metrics, - runtime.clone(), + context.clone(), )); // In a separate task, wait for each input to be done @@ -300,13 +300,13 @@ impl RepartitionExec { mut txs: HashMap>>>, partitioning: Partitioning, r_metrics: RepartitionMetrics, - runtime: Arc, + context: Arc, ) -> Result<()> { let num_output_partitions = txs.len(); // execute the child operator let timer = r_metrics.fetch_time.timer(); - let mut stream = input.execute(i, runtime).await?; + let mut stream = input.execute(i, context).await?; timer.done(); let mut counter = 0; @@ -503,6 +503,7 @@ impl RecordBatchStream for RepartitionStream { mod tests { use super::*; use crate::from_slice::FromSlice; + use crate::prelude::SessionContext; use crate::test::create_vec_batches; use crate::{ assert_batches_sorted_eq, @@ -616,7 +617,8 @@ mod tests { input_partitions: Vec>, partitioning: Partitioning, ) -> Result>> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); // create physical plan let exec = MemoryExec::try_new(&input_partitions, schema.clone(), None)?; let exec = RepartitionExec::try_new(Arc::new(exec), partitioning)?; @@ -625,7 +627,7 @@ mod tests { let mut output_partitions = vec![]; for i in 0..exec.partitioning.partition_count() { // execute this *output* partition and collect all batches - let mut stream = exec.execute(i, runtime.clone()).await?; + let mut stream = exec.execute(i, task_ctx.clone()).await?; let mut batches = vec![]; while let Some(result) = stream.next().await { batches.push(result?); @@ -665,7 +667,8 @@ mod tests { #[tokio::test] async fn unsupported_partitioning() { - let runtime = 
Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); // have to send at least one batch through to provoke error let batch = RecordBatch::try_from_iter(vec![( "my_awesome_field", @@ -680,7 +683,7 @@ mod tests { // returned and no results produced let partitioning = Partitioning::UnknownPartitioning(1); let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); - let output_stream = exec.execute(0, runtime).await.unwrap(); + let output_stream = exec.execute(0, task_ctx).await.unwrap(); // Expect that an error is returned let result_string = crate::physical_plan::common::collect(output_stream) @@ -700,14 +703,15 @@ mod tests { // This generates an error on a call to execute. The error // should be returned and no results produced. - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let input = ErrorExec::new(); let partitioning = Partitioning::RoundRobinBatch(1); let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); // Note: this should pass (the stream can be created) but the // error when the input is executed should get passed back - let output_stream = exec.execute(0, runtime).await.unwrap(); + let output_stream = exec.execute(0, task_ctx).await.unwrap(); // Expect that an error is returned let result_string = crate::physical_plan::common::collect(output_stream) @@ -723,7 +727,8 @@ mod tests { #[tokio::test] async fn repartition_with_error_in_stream() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let batch = RecordBatch::try_from_iter(vec![( "my_awesome_field", Arc::new(StringArray::from_slice(&["foo", "bar"])) as ArrayRef, @@ -741,7 +746,7 @@ mod tests { // Note: this should pass (the stream can be created) but the // error when the input is executed should get passed back - let output_stream = exec.execute(0, runtime).await.unwrap(); + let output_stream = exec.execute(0, task_ctx).await.unwrap(); // Expect that an error is returned let result_string = crate::physical_plan::common::collect(output_stream) @@ -757,7 +762,8 @@ mod tests { #[tokio::test] async fn repartition_with_delayed_stream() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let batch1 = RecordBatch::try_from_iter(vec![( "my_awesome_field", Arc::new(StringArray::from_slice(&["foo", "bar"])) as ArrayRef, @@ -792,7 +798,7 @@ mod tests { assert_batches_sorted_eq!(&expected, &expected_batches); - let output_stream = exec.execute(0, runtime).await.unwrap(); + let output_stream = exec.execute(0, task_ctx).await.unwrap(); let batches = crate::physical_plan::common::collect(output_stream) .await .unwrap(); @@ -802,7 +808,8 @@ mod tests { #[tokio::test] async fn robin_repartition_with_dropping_output_stream() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let partitioning = Partitioning::RoundRobinBatch(2); // The barrier exec waits to be pinged // requires the input to wait at least once) @@ -811,8 +818,8 @@ mod tests { // partition into two output streams let exec = RepartitionExec::try_new(input.clone(), partitioning).unwrap(); - let output_stream0 = exec.execute(0, runtime.clone()).await.unwrap(); - let output_stream1 = exec.execute(1, runtime.clone()).await.unwrap(); + let 
output_stream0 = exec.execute(0, task_ctx.clone()).await.unwrap(); + let output_stream1 = exec.execute(1, task_ctx.clone()).await.unwrap(); // now, purposely drop output stream 0 // *before* any outputs are produced @@ -845,7 +852,8 @@ mod tests { // wiht different compilers, we will compare the same execution with // and without droping the output stream. async fn hash_repartition_with_dropping_output_stream() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let partitioning = Partitioning::Hash( vec![Arc::new(crate::physical_plan::expressions::Column::new( "my_awesome_field", @@ -857,7 +865,7 @@ mod tests { // We first collect the results without droping the output stream. let input = Arc::new(make_barrier_exec()); let exec = RepartitionExec::try_new(input.clone(), partitioning.clone()).unwrap(); - let output_stream1 = exec.execute(1, runtime.clone()).await.unwrap(); + let output_stream1 = exec.execute(1, task_ctx.clone()).await.unwrap(); input.wait().await; let batches_without_drop = crate::physical_plan::common::collect(output_stream1) .await @@ -877,8 +885,8 @@ mod tests { // Now do the same but dropping the stream before waiting for the barrier let input = Arc::new(make_barrier_exec()); let exec = RepartitionExec::try_new(input.clone(), partitioning).unwrap(); - let output_stream0 = exec.execute(0, runtime.clone()).await.unwrap(); - let output_stream1 = exec.execute(1, runtime.clone()).await.unwrap(); + let output_stream0 = exec.execute(0, task_ctx.clone()).await.unwrap(); + let output_stream1 = exec.execute(1, task_ctx.clone()).await.unwrap(); // now, purposely drop output stream 0 // *before* any outputs are produced std::mem::drop(output_stream0); @@ -943,7 +951,8 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -954,7 +963,7 @@ mod tests { Partitioning::UnknownPartitioning(1), )?); - let fut = collect(repartition_exec, runtime); + let fut = collect(repartition_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); @@ -966,7 +975,8 @@ mod tests { #[tokio::test] async fn hash_repartition_avoid_empty_batch() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let batch = RecordBatch::try_from_iter(vec![( "a", Arc::new(StringArray::from_slice(&["foo"])) as ArrayRef, @@ -981,11 +991,11 @@ mod tests { let schema = batch.schema(); let input = MockExec::new(vec![Ok(batch)], schema); let exec = RepartitionExec::try_new(Arc::new(input), partitioning).unwrap(); - let output_stream0 = exec.execute(0, runtime.clone()).await.unwrap(); + let output_stream0 = exec.execute(0, task_ctx.clone()).await.unwrap(); let batch0 = crate::physical_plan::common::collect(output_stream0) .await .unwrap(); - let output_stream1 = exec.execute(1, runtime.clone()).await.unwrap(); + let output_stream1 = exec.execute(1, task_ctx.clone()).await.unwrap(); let batch1 = crate::physical_plan::common::collect(output_stream1) .await .unwrap(); diff --git a/datafusion/src/physical_plan/sorts/sort.rs b/datafusion/src/physical_plan/sorts/sort.rs index 1428e1627d8f8..b2bf604665a08 100644 --- a/datafusion/src/physical_plan/sorts/sort.rs +++ b/datafusion/src/physical_plan/sorts/sort.rs @@ 
-20,6 +20,7 @@ //! but spills to disk if needed. use crate::error::{DataFusionError, Result}; +use crate::execution::context::TaskContext; use crate::execution::memory_manager::{ human_readable_size, ConsumerType, MemoryConsumer, MemoryConsumerId, MemoryManager, }; @@ -476,7 +477,7 @@ impl ExecutionPlan for SortExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { if !self.preserve_partitioning { if 0 != partition { @@ -494,14 +495,14 @@ impl ExecutionPlan for SortExec { } } - let input = self.input.execute(partition, runtime.clone()).await?; + let input = self.input.execute(partition, context.clone()).await?; do_sort( input, partition, self.expr.clone(), self.metrics_set.clone(), - runtime, + context, ) .await } @@ -568,7 +569,7 @@ async fn do_sort( partition_id: usize, expr: Vec, metrics_set: CompositeMetricsSet, - runtime: Arc, + context: Arc, ) -> Result { let schema = input.schema(); let sorter = ExternalSorter::new( @@ -576,9 +577,9 @@ async fn do_sort( schema.clone(), expr, metrics_set, - runtime.clone(), + context.runtime.clone(), ); - runtime.register_requester(sorter.id()); + context.runtime.register_requester(sorter.id()); while let Some(batch) = input.next().await { let batch = batch?; sorter.insert_batch(batch).await?; @@ -590,7 +591,7 @@ async fn do_sort( mod tests { use super::*; use crate::datasource::object_store::local::LocalFileSystem; - use crate::execution::context::ExecutionConfig; + use crate::execution::context::SessionConfig; use crate::physical_plan::coalesce_partitions::CoalescePartitionsExec; use crate::physical_plan::expressions::col; use crate::physical_plan::memory::MemoryExec; @@ -598,6 +599,7 @@ mod tests { collect, file_format::{CsvExec, FileScanConfig}, }; + use crate::prelude::SessionContext; use crate::test; use crate::test::assert_is_pending; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; @@ -610,7 +612,8 @@ mod tests { #[tokio::test] async fn test_in_mem_sort() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = test_util::aggr_test_schema(); let partitions = 4; let (_, files) = @@ -651,7 +654,7 @@ mod tests { Arc::new(CoalescePartitionsExec::new(Arc::new(csv))), )?); - let result = collect(sort_exec, runtime).await?; + let result = collect(sort_exec, task_ctx).await?; assert_eq!(result.len(), 1); @@ -675,8 +678,8 @@ mod tests { #[tokio::test] async fn test_sort_spill() -> Result<()> { // trigger spill there will be 4 batches with 5.5KB for each - let config = ExecutionConfig::new().with_memory_limit(12288, 1.0)?; - let runtime = Arc::new(RuntimeEnv::new(config.runtime)?); + let config = SessionConfig::new().with_memory_limit(12288, 1.0)?; + let session_ctx = SessionContext::with_config(config); let schema = test_util::aggr_test_schema(); let partitions = 4; @@ -718,7 +721,8 @@ mod tests { Arc::new(CoalescePartitionsExec::new(Arc::new(csv))), )?); - let result = collect(sort_exec.clone(), runtime).await?; + let task_ctx = session_ctx.task_ctx(); + let result = collect(sort_exec.clone(), task_ctx).await?; assert_eq!(result.len(), 1); @@ -749,7 +753,8 @@ mod tests { #[tokio::test] async fn test_sort_metadata() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let field_metadata: BTreeMap = vec![("foo".to_string(), "bar".to_string())] .into_iter() @@ -779,7 +784,7 @@ mod tests 
{ input, )?); - let result: Vec = collect(sort_exec, runtime).await?; + let result: Vec = collect(sort_exec, task_ctx).await?; let expected_data: ArrayRef = Arc::new(vec![1, 2, 3].into_iter().map(Some).collect::()); @@ -801,7 +806,8 @@ mod tests { #[tokio::test] async fn test_lex_sort_by_float() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Float32, true), Field::new("b", DataType::Float64, true), @@ -857,7 +863,7 @@ mod tests { assert_eq!(DataType::Float32, *sort_exec.schema().field(0).data_type()); assert_eq!(DataType::Float64, *sort_exec.schema().field(1).data_type()); - let result: Vec = collect(sort_exec.clone(), runtime).await?; + let result: Vec = collect(sort_exec.clone(), task_ctx).await?; let metrics = sort_exec.metrics().unwrap(); assert!(metrics.elapsed_compute().unwrap() > 0); assert_eq!(metrics.output_rows().unwrap(), 8); @@ -906,7 +912,8 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -920,7 +927,7 @@ mod tests { blocking_exec, )?); - let fut = collect(sort_exec, runtime); + let fut = collect(sort_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs index 780e2cc676593..b04af04ba5bfd 100644 --- a/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs +++ b/datafusion/src/physical_plan/sorts/sort_preserving_merge.rs @@ -42,6 +42,7 @@ use futures::stream::FusedStream; use futures::{Stream, StreamExt}; use crate::error::{DataFusionError, Result}; +use crate::execution::context::TaskContext; use crate::execution::runtime_env::RuntimeEnv; use crate::physical_plan::sorts::{RowIndex, SortKeyCursor, SortedStream, StreamWrapper}; use crate::physical_plan::{ @@ -158,7 +159,7 @@ impl ExecutionPlan for SortPreservingMergeExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { if 0 != partition { return Err(DataFusionError::Internal(format!( @@ -177,7 +178,7 @@ impl ExecutionPlan for SortPreservingMergeExec { )), 1 => { // bypass if there is only one partition to merge (no metrics in this case either) - self.input.execute(0, runtime).await + self.input.execute(0, context).await } _ => { let (receivers, join_handles) = (0..input_partitions) @@ -188,7 +189,7 @@ impl ExecutionPlan for SortPreservingMergeExec { self.input.clone(), sender, part_i, - runtime.clone(), + context.clone(), ); (receiver, join_handle) }) @@ -200,7 +201,7 @@ impl ExecutionPlan for SortPreservingMergeExec { self.schema(), &self.expr, tracking_metrics, - runtime, + context.runtime.clone(), ))) } } @@ -623,14 +624,15 @@ mod tests { use crate::{assert_batches_eq, test_util}; use super::*; - use crate::execution::runtime_env::RuntimeConfig; + use crate::prelude::{SessionConfig, SessionContext}; use arrow::datatypes::{DataType, Field, Schema}; use futures::{FutureExt, SinkExt}; use tokio_stream::StreamExt; #[tokio::test] async fn test_merge_interleave() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let a: ArrayRef = 
Arc::new(Int32Array::from_slice(&[1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ Some("a"), @@ -671,14 +673,15 @@ mod tests { "| 3 | j | 1970-01-01 00:00:00.000000008 |", "+----+---+-------------------------------+", ], - runtime, + task_ctx, ) .await; } #[tokio::test] async fn test_merge_some_overlap() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ Some("a"), @@ -719,14 +722,15 @@ mod tests { "| 110 | g | 1970-01-01 00:00:00.000000006 |", "+-----+---+-------------------------------+", ], - runtime, + task_ctx, ) .await; } #[tokio::test] async fn test_merge_no_overlap() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ Some("a"), @@ -767,14 +771,15 @@ mod tests { "| 30 | j | 1970-01-01 00:00:00.000000006 |", "+----+---+-------------------------------+", ], - runtime, + task_ctx, ) .await; } #[tokio::test] async fn test_merge_three_partitions() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ Some("a"), @@ -832,7 +837,7 @@ mod tests { "| 30 | j | 1970-01-01 00:00:00.000000060 |", "+-----+---+-------------------------------+", ], - runtime, + task_ctx, ) .await; } @@ -840,7 +845,7 @@ mod tests { async fn _test_merge( partitions: &[Vec], exp: &[&str], - runtime: Arc, + context: Arc, ) { let schema = partitions[0][0].schema(); let sort = vec![ @@ -856,17 +861,17 @@ mod tests { let exec = MemoryExec::try_new(partitions, schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); - let collected = collect(merge, runtime).await.unwrap(); + let collected = collect(merge, context).await.unwrap(); assert_batches_eq!(exp, collected.as_slice()); } async fn sorted_merge( input: Arc, sort: Vec, - runtime: Arc, + context: Arc, ) -> RecordBatch { let merge = Arc::new(SortPreservingMergeExec::new(sort, input)); - let mut result = collect(merge, runtime).await.unwrap(); + let mut result = collect(merge, context).await.unwrap(); assert_eq!(result.len(), 1); result.remove(0) } @@ -874,28 +879,29 @@ mod tests { async fn partition_sort( input: Arc, sort: Vec, - runtime: Arc, + context: Arc, ) -> RecordBatch { let sort_exec = Arc::new(SortExec::new_with_partitioning(sort.clone(), input, true)); - sorted_merge(sort_exec, sort, runtime).await + sorted_merge(sort_exec, sort, context).await } async fn basic_sort( src: Arc, sort: Vec, - runtime: Arc, + context: Arc, ) -> RecordBatch { let merge = Arc::new(CoalescePartitionsExec::new(src)); let sort_exec = Arc::new(SortExec::try_new(sort, merge).unwrap()); - let mut result = collect(sort_exec, runtime).await.unwrap(); + let mut result = collect(sort_exec, context).await.unwrap(); assert_eq!(result.len(), 1); result.remove(0) } #[tokio::test] async fn test_partition_sort() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = test_util::aggr_test_schema(); let partitions = 4; let (_, 
files) = @@ -937,8 +943,8 @@ mod tests { }, ]; - let basic = basic_sort(csv.clone(), sort.clone(), runtime.clone()).await; - let partition = partition_sort(csv, sort, runtime.clone()).await; + let basic = basic_sort(csv.clone(), sort.clone(), task_ctx.clone()).await; + let partition = partition_sort(csv, sort, task_ctx.clone()).await; let basic = arrow::util::pretty::pretty_format_batches(&[basic]) .unwrap() @@ -981,7 +987,7 @@ mod tests { async fn sorted_partitioned_input( sort: Vec, sizes: &[usize], - runtime: Arc, + context: Arc, ) -> Arc { let schema = test_util::aggr_test_schema(); let partitions = 4; @@ -1002,7 +1008,7 @@ mod tests { b',', )); - let sorted = basic_sort(csv, sort, runtime).await; + let sorted = basic_sort(csv, sort, context).await; let split: Vec<_> = sizes.iter().map(|x| split_batch(&sorted, *x)).collect(); Arc::new(MemoryExec::try_new(&split, sorted.schema(), None).unwrap()) @@ -1010,7 +1016,8 @@ mod tests { #[tokio::test] async fn test_partition_sort_streaming_input() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = test_util::aggr_test_schema(); let sort = vec![ // uint8 @@ -1036,9 +1043,9 @@ mod tests { ]; let input = - sorted_partitioned_input(sort.clone(), &[10, 3, 11], runtime.clone()).await; - let basic = basic_sort(input.clone(), sort.clone(), runtime.clone()).await; - let partition = sorted_merge(input, sort, runtime.clone()).await; + sorted_partitioned_input(sort.clone(), &[10, 3, 11], task_ctx.clone()).await; + let basic = basic_sort(input.clone(), sort.clone(), task_ctx.clone()).await; + let partition = sorted_merge(input, sort, task_ctx.clone()).await; assert_eq!(basic.num_rows(), 300); assert_eq!(partition.num_rows(), 300); @@ -1070,15 +1077,18 @@ mod tests { }, ]; - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let input = - sorted_partitioned_input(sort.clone(), &[10, 5, 13], runtime.clone()).await; - let basic = basic_sort(input.clone(), sort.clone(), runtime).await; + sorted_partitioned_input(sort.clone(), &[10, 5, 13], task_ctx.clone()).await; + let basic = basic_sort(input.clone(), sort.clone(), task_ctx).await; + + let session_ctx_bs_23 = + SessionContext::with_config(SessionConfig::new().with_batch_size(23)); - let runtime_bs_23 = - Arc::new(RuntimeEnv::new(RuntimeConfig::new().with_batch_size(23)).unwrap()); let merge = Arc::new(SortPreservingMergeExec::new(sort, input)); - let merged = collect(merge, runtime_bs_23).await.unwrap(); + let task_ctx = session_ctx_bs_23.task_ctx(); + let merged = collect(merge, task_ctx).await.unwrap(); assert_eq!(merged.len(), 14); @@ -1097,7 +1107,8 @@ mod tests { #[tokio::test] async fn test_nulls() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2, 7, 9, 3])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![ None, @@ -1152,7 +1163,7 @@ mod tests { let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); - let collected = collect(merge, runtime).await.unwrap(); + let collected = collect(merge, task_ctx).await.unwrap(); assert_eq!(collected.len(), 1); assert_batches_eq!( @@ -1178,7 +1189,8 @@ mod tests { #[tokio::test] async fn test_async() { - let runtime = Arc::new(RuntimeEnv::default()); + 
let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = test_util::aggr_test_schema(); let sort = vec![PhysicalSortExpr { expr: col("c12", &schema).unwrap(), @@ -1186,7 +1198,7 @@ mod tests { }]; let batches = - sorted_partitioned_input(sort.clone(), &[5, 7, 3], runtime.clone()).await; + sorted_partitioned_input(sort.clone(), &[5, 7, 3], task_ctx.clone()).await; let partition_count = batches.output_partitioning().partition_count(); let mut join_handles = Vec::with_capacity(partition_count); @@ -1194,7 +1206,7 @@ mod tests { for partition in 0..partition_count { let (mut sender, receiver) = mpsc::channel(1); - let mut stream = batches.execute(partition, runtime.clone()).await.unwrap(); + let mut stream = batches.execute(partition, task_ctx.clone()).await.unwrap(); let join_handle = tokio::spawn(async move { while let Some(batch) = stream.next().await { sender.send(batch).await.unwrap(); @@ -1216,7 +1228,7 @@ mod tests { batches.schema(), sort.as_slice(), tracking_metrics, - runtime.clone(), + task_ctx.runtime.clone(), ); let mut merged = common::collect(Box::pin(merge_stream)).await.unwrap(); @@ -1228,7 +1240,7 @@ mod tests { assert_eq!(merged.len(), 1); let merged = merged.remove(0); - let basic = basic_sort(batches, sort.clone(), runtime.clone()).await; + let basic = basic_sort(batches, sort.clone(), task_ctx.clone()).await; let basic = arrow::util::pretty::pretty_format_batches(&[basic]) .unwrap() @@ -1246,7 +1258,8 @@ mod tests { #[tokio::test] async fn test_merge_metrics() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let a: ArrayRef = Arc::new(Int32Array::from_slice(&[1, 2])); let b: ArrayRef = Arc::new(StringArray::from_iter(vec![Some("a"), Some("c")])); let b1 = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); @@ -1263,7 +1276,7 @@ mod tests { let exec = MemoryExec::try_new(&[vec![b1], vec![b2]], schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); - let collected = collect(merge.clone(), runtime).await.unwrap(); + let collected = collect(merge.clone(), task_ctx).await.unwrap(); let expected = vec![ "+----+---+", "| a | b |", @@ -1302,7 +1315,8 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -1316,7 +1330,7 @@ mod tests { blocking_exec, )); - let fut = collect(sort_preserving_merge_exec, runtime); + let fut = collect(sort_preserving_merge_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); @@ -1328,7 +1342,8 @@ mod tests { #[tokio::test] async fn test_stable_sort() { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); // Create record batches like: // batch_number |value @@ -1368,7 +1383,7 @@ mod tests { let exec = MemoryExec::try_new(&partitions, schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); - let collected = collect(merge, runtime).await.unwrap(); + let collected = collect(merge, task_ctx).await.unwrap(); assert_eq!(collected.len(), 1); // Expect the data to be sorted first by "batch_number" (because diff --git a/datafusion/src/physical_plan/union.rs b/datafusion/src/physical_plan/union.rs index 
48f7b280b80e8..bd040df2f1b93 100644 --- a/datafusion/src/physical_plan/union.rs +++ b/datafusion/src/physical_plan/union.rs @@ -32,7 +32,7 @@ use super::{ ColumnStatistics, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, }; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::{ error::Result, physical_plan::{expressions, metrics::BaselineMetrics}, @@ -104,7 +104,7 @@ impl ExecutionPlan for UnionExec { async fn execute( &self, mut partition: usize, - runtime: Arc, + context: Arc, ) -> Result { let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); // record the tiny amount of work done in this function so @@ -116,7 +116,7 @@ impl ExecutionPlan for UnionExec { for input in self.inputs.iter() { // Calculate whether partition belongs to the current partition if partition < input.output_partitioning().partition_count() { - let stream = input.execute(partition, runtime.clone()).await?; + let stream = input.execute(partition, context.clone()).await?; return Ok(Box::pin(ObservedStream::new(stream, baseline_metrics))); } else { partition -= input.output_partitioning().partition_count(); @@ -237,6 +237,7 @@ mod tests { use crate::datasource::object_store::{local::LocalFileSystem, ObjectStore}; use crate::{test, test_util}; + use crate::prelude::SessionContext; use crate::{ physical_plan::{ collect, @@ -248,7 +249,8 @@ mod tests { #[tokio::test] async fn test_union_partitions() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = test_util::aggr_test_schema(); let fs: Arc = Arc::new(LocalFileSystem {}); @@ -289,7 +291,7 @@ mod tests { // Should have 9 partitions and 9 output batches assert_eq!(union_exec.output_partitioning().partition_count(), 9); - let result: Vec = collect(union_exec, runtime).await?; + let result: Vec = collect(union_exec, task_ctx).await?; assert_eq!(result.len(), 9); Ok(()) diff --git a/datafusion/src/physical_plan/values.rs b/datafusion/src/physical_plan/values.rs index c65082ef0677f..8cc448df5b98d 100644 --- a/datafusion/src/physical_plan/values.rs +++ b/datafusion/src/physical_plan/values.rs @@ -20,7 +20,7 @@ use super::expressions::PhysicalSortExpr; use super::{common, SendableRecordBatchStream, Statistics}; use crate::error::{DataFusionError, Result}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::physical_plan::{ memory::MemoryStream, ColumnarValue, DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalExpr, @@ -146,7 +146,7 @@ impl ExecutionPlan for ValuesExec { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { // GlobalLimitExec has a single output partition if 0 != partition { diff --git a/datafusion/src/physical_plan/windows/mod.rs b/datafusion/src/physical_plan/windows/mod.rs index e833c57c5b5ee..03e7342e938bc 100644 --- a/datafusion/src/physical_plan/windows/mod.rs +++ b/datafusion/src/physical_plan/windows/mod.rs @@ -153,11 +153,11 @@ fn create_built_in_window_expr( mod tests { use super::*; use crate::datasource::object_store::local::LocalFileSystem; - use crate::execution::runtime_env::RuntimeEnv; use crate::physical_plan::aggregates::AggregateFunction; use crate::physical_plan::expressions::col; use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; use crate::physical_plan::{collect, Statistics}; + use 
crate::prelude::SessionContext; use crate::test::exec::{assert_strong_count_converges_to_zero, BlockingExec}; use crate::test::{self, assert_is_pending}; use crate::test_util::{self, aggr_test_schema}; @@ -190,7 +190,8 @@ mod tests { #[tokio::test] async fn window_function() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let (input, schema) = create_test_schema(1)?; let window_exec = Arc::new(WindowAggExec::try_new( @@ -227,7 +228,7 @@ mod tests { schema.clone(), )?); - let result: Vec = collect(window_exec, runtime).await?; + let result: Vec = collect(window_exec, task_ctx).await?; assert_eq!(result.len(), 1); let columns = result[0].columns(); @@ -251,7 +252,8 @@ mod tests { #[tokio::test] async fn test_drop_cancel() -> Result<()> { - let runtime = Arc::new(RuntimeEnv::default()); + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Float32, true)])); @@ -271,7 +273,7 @@ mod tests { schema, )?); - let fut = collect(window_agg_exec, runtime); + let fut = collect(window_agg_exec, task_ctx); let mut fut = fut.boxed(); assert_is_pending(&mut fut); diff --git a/datafusion/src/physical_plan/windows/window_agg_exec.rs b/datafusion/src/physical_plan/windows/window_agg_exec.rs index 163868d078386..f59bc910bffac 100644 --- a/datafusion/src/physical_plan/windows/window_agg_exec.rs +++ b/datafusion/src/physical_plan/windows/window_agg_exec.rs @@ -18,7 +18,7 @@ //! Stream and channel implementations for window function expressions. use crate::error::{DataFusionError, Result}; -use crate::execution::runtime_env::RuntimeEnv; +use crate::execution::context::TaskContext; use crate::physical_plan::common::AbortOnDropSingle; use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::metrics::{ @@ -158,9 +158,9 @@ impl ExecutionPlan for WindowAggExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { - let input = self.input.execute(partition, runtime).await?; + let input = self.input.execute(partition, context).await?; let stream = Box::pin(WindowAggStream::new( self.schema.clone(), self.window_expr.clone(), diff --git a/datafusion/src/prelude.rs b/datafusion/src/prelude.rs index 0aff006c7896d..e40693e715662 100644 --- a/datafusion/src/prelude.rs +++ b/datafusion/src/prelude.rs @@ -26,7 +26,7 @@ //! 
``` pub use crate::dataframe::DataFrame; -pub use crate::execution::context::{ExecutionConfig, ExecutionContext}; +pub use crate::execution::context::{SessionConfig, SessionContext}; pub use crate::execution::options::AvroReadOptions; pub use crate::execution::options::{CsvReadOptions, NdJsonReadOptions}; pub use crate::logical_plan::{ diff --git a/datafusion/src/test/exec.rs b/datafusion/src/test/exec.rs index 5a6b27865d133..41d3c55ead73e 100644 --- a/datafusion/src/test/exec.rs +++ b/datafusion/src/test/exec.rs @@ -33,6 +33,8 @@ use arrow::{ }; use futures::Stream; +use crate::execution::context::TaskContext; +use crate::physical_plan::expressions::PhysicalSortExpr; use crate::physical_plan::{ common, DisplayFormatType, ExecutionPlan, Partitioning, RecordBatchStream, SendableRecordBatchStream, Statistics, @@ -41,9 +43,6 @@ use crate::{ error::{DataFusionError, Result}, physical_plan::stream::RecordBatchReceiverStream, }; -use crate::{ - execution::runtime_env::RuntimeEnv, physical_plan::expressions::PhysicalSortExpr, -}; /// Index into the data that has been returned so far #[derive(Debug, Default, Clone)] @@ -172,7 +171,7 @@ impl ExecutionPlan for MockExec { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { assert_eq!(partition, 0); @@ -311,7 +310,7 @@ impl ExecutionPlan for BarrierExec { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { assert!(partition < self.data.len()); @@ -412,7 +411,7 @@ impl ExecutionPlan for ErrorExec { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { Err(DataFusionError::Internal(format!( "ErrorExec, unsurprisingly, errored in partition {}", @@ -497,7 +496,7 @@ impl ExecutionPlan for StatisticsExec { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { unimplemented!("This plan only serves for testing statistics") } @@ -595,7 +594,7 @@ impl ExecutionPlan for BlockingExec { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { Ok(Box::pin(BlockingStream { schema: Arc::clone(&self.schema), diff --git a/datafusion/tests/custom_sources.rs b/datafusion/tests/custom_sources.rs index 926a017f14af4..c2ff27d4f0485 100644 --- a/datafusion/tests/custom_sources.rs +++ b/datafusion/tests/custom_sources.rs @@ -30,7 +30,7 @@ use datafusion::{ physical_plan::DisplayFormatType, }; -use datafusion::execution::context::ExecutionContext; +use datafusion::execution::context::{SessionContext, TaskContext}; use datafusion::logical_plan::{ col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE, }; @@ -46,7 +46,6 @@ use std::sync::Arc; use std::task::{Context, Poll}; use async_trait::async_trait; -use datafusion::execution::runtime_env::RuntimeEnv; use datafusion::logical_plan::plan::Projection; //// Custom source dataframe tests //// @@ -135,7 +134,7 @@ impl ExecutionPlan for CustomExecutionPlan { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { Ok(Box::pin(TestCustomRecordBatchStream { nb_batch: 1 })) } @@ -211,7 +210,7 @@ impl TableProvider for CustomTableProvider { #[tokio::test] async fn custom_source_dataframe() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let table = ctx.read_table(Arc::new(CustomTableProvider))?; let logical_plan = LogicalPlanBuilder::from(table.to_logical_plan()) @@ -246,8 +245,8 @@ async fn custom_source_dataframe() -> Result<()> { assert_eq!(1, 
physical_plan.schema().fields().len()); assert_eq!("c2", physical_plan.schema().field(0).name().as_str()); - let runtime = ctx.state.lock().runtime_env.clone(); - let batches = collect(physical_plan, runtime).await?; + let task_ctx = ctx.task_ctx(); + let batches = collect(physical_plan, task_ctx).await?; let origin_rec_batch = TEST_CUSTOM_RECORD_BATCH!()?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); @@ -258,7 +257,7 @@ async fn custom_source_dataframe() -> Result<()> { #[tokio::test] async fn optimizers_catch_all_statistics() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("test", Arc::new(CustomTableProvider)) .unwrap(); @@ -293,8 +292,8 @@ async fn optimizers_catch_all_statistics() { ) .unwrap(); - let runtime = ctx.state.lock().runtime_env.clone(); - let actual = collect(physical_plan, runtime).await.unwrap(); + let task_ctx = ctx.task_ctx(); + let actual = collect(physical_plan, task_ctx).await.unwrap(); assert_eq!(actual.len(), 1); assert_eq!(format!("{:?}", actual[0]), format!("{:?}", expected)); diff --git a/datafusion/tests/dataframe.rs b/datafusion/tests/dataframe.rs index 116315e9b9b27..3f106be73b4b6 100644 --- a/datafusion/tests/dataframe.rs +++ b/datafusion/tests/dataframe.rs @@ -25,7 +25,7 @@ use std::sync::Arc; use datafusion::assert_batches_eq; use datafusion::error::Result; -use datafusion::execution::context::ExecutionContext; +use datafusion::execution::context::SessionContext; use datafusion::logical_plan::{col, Expr}; use datafusion::{datasource::MemTable, prelude::JoinType}; use datafusion_expr::lit; @@ -58,7 +58,7 @@ async fn join() -> Result<()> { ], )?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let table1 = MemTable::try_new(schema1, vec![vec![batch1]])?; let table2 = MemTable::try_new(schema2, vec![vec![batch2]])?; @@ -96,7 +96,7 @@ async fn sort_on_unprojected_columns() -> Result<()> { ) .unwrap(); - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap(); ctx.register_table("t", Arc::new(provider)).unwrap(); @@ -132,7 +132,7 @@ async fn filter_with_alias_overwrite() -> Result<()> { ) .unwrap(); - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]]).unwrap(); ctx.register_table("t", Arc::new(provider)).unwrap(); diff --git a/datafusion/tests/dataframe_functions.rs b/datafusion/tests/dataframe_functions.rs index 1f55af4513db6..8b7c5d8944167 100644 --- a/datafusion/tests/dataframe_functions.rs +++ b/datafusion/tests/dataframe_functions.rs @@ -31,7 +31,7 @@ use datafusion::error::Result; // use datafusion::logical_plan::Expr; use datafusion::prelude::*; -use datafusion::execution::context::ExecutionContext; +use datafusion::execution::context::SessionContext; use datafusion::assert_batches_eq; @@ -55,7 +55,7 @@ fn create_test_table() -> Result> { ], )?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let table = MemTable::try_new(schema, vec![vec![batch]])?; diff --git a/datafusion/tests/merge_fuzz.rs b/datafusion/tests/merge_fuzz.rs index d874ec507c49a..31ccc679cef29 100644 --- a/datafusion/tests/merge_fuzz.rs +++ b/datafusion/tests/merge_fuzz.rs @@ -23,15 +23,13 @@ use arrow::{ compute::SortOptions, record_batch::RecordBatch, }; -use datafusion::{ - execution::runtime_env::{RuntimeConfig, RuntimeEnv}, - physical_plan::{ - 
collect, - expressions::{col, PhysicalSortExpr}, - memory::MemoryExec, - sorts::sort_preserving_merge::SortPreservingMergeExec, - }, +use datafusion::physical_plan::{ + collect, + expressions::{col, PhysicalSortExpr}, + memory::MemoryExec, + sorts::sort_preserving_merge::SortPreservingMergeExec, }; +use datafusion::prelude::{SessionConfig, SessionContext}; use fuzz_utils::{add_empty_batches, batches_to_vec, partitions_to_sorted_vec}; use rand::{prelude::StdRng, Rng, SeedableRng}; @@ -120,10 +118,10 @@ async fn run_merge_test(input: Vec>) { let exec = MemoryExec::try_new(&input, schema, None).unwrap(); let merge = Arc::new(SortPreservingMergeExec::new(sort, Arc::new(exec))); - let runtime_config = RuntimeConfig::new().with_batch_size(batch_size); - - let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap()); - let collected = collect(merge, runtime).await.unwrap(); + let session_config = SessionConfig::new().with_batch_size(batch_size); + let ctx = SessionContext::with_config(session_config); + let task_ctx = ctx.task_ctx(); + let collected = collect(merge, task_ctx).await.unwrap(); // verify the output batch size: all batches except the last // should contain `batch_size` rows diff --git a/datafusion/tests/order_spill_fuzz.rs b/datafusion/tests/order_spill_fuzz.rs index b1586f06c02c8..53dadb7f17438 100644 --- a/datafusion/tests/order_spill_fuzz.rs +++ b/datafusion/tests/order_spill_fuzz.rs @@ -23,11 +23,12 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion::execution::memory_manager::MemoryManagerConfig; -use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; +use datafusion::execution::runtime_env::RuntimeConfig; use datafusion::physical_plan::expressions::{col, PhysicalSortExpr}; use datafusion::physical_plan::memory::MemoryExec; use datafusion::physical_plan::sorts::sort::SortExec; use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::prelude::{SessionConfig, SessionContext}; use fuzz_utils::{add_empty_batches, batches_to_vec, partitions_to_sorted_vec}; use rand::prelude::StdRng; use rand::{Rng, SeedableRng}; @@ -77,8 +78,11 @@ async fn run_sort(pool_size: usize, size_spill: Vec<(usize, bool)>) { let runtime_config = RuntimeConfig::new().with_memory_manager( MemoryManagerConfig::try_new_limit(pool_size, 1.0).unwrap(), ); - let runtime = Arc::new(RuntimeEnv::new(runtime_config).unwrap()); - let collected = collect(sort.clone(), runtime).await.unwrap(); + let session_config = SessionConfig::new().with_runtime_config(runtime_config); + let session_ctx = SessionContext::with_config(session_config); + + let task_ctx = session_ctx.task_ctx(); + let collected = collect(sort.clone(), task_ctx).await.unwrap(); let expected = partitions_to_sorted_vec(&input); let actual = batches_to_vec(&collected); diff --git a/datafusion/tests/parquet_pruning.rs b/datafusion/tests/parquet_pruning.rs index 9869a1f6b16ac..c5428e499a68a 100644 --- a/datafusion/tests/parquet_pruning.rs +++ b/datafusion/tests/parquet_pruning.rs @@ -37,7 +37,7 @@ use datafusion::{ accept, file_format::ParquetExec, metrics::MetricsSet, ExecutionPlan, ExecutionPlanVisitor, }, - prelude::{ExecutionConfig, ExecutionContext}, + prelude::{SessionConfig, SessionContext}, scalar::ScalarValue, }; use parquet::{arrow::ArrowWriter, file::properties::WriterProperties}; @@ -161,7 +161,7 @@ async fn prune_disabled() { ); // same query, without pruning - let config = ExecutionConfig::new().with_parquet_pruning(false); + let config = SessionConfig::new().with_parquet_pruning(false); let output = 
ContextWithParquet::with_config(Scenario::Timestamps, config) .await @@ -424,7 +424,7 @@ struct ContextWithParquet { /// when dropped file: NamedTempFile, provider: Arc, - ctx: ExecutionContext, + ctx: SessionContext, } /// The output of running one of the test cases @@ -472,15 +472,15 @@ impl TestOutput { /// and the appropriate scenario impl ContextWithParquet { async fn new(scenario: Scenario) -> Self { - Self::with_config(scenario, ExecutionConfig::new()).await + Self::with_config(scenario, SessionConfig::new()).await } - async fn with_config(scenario: Scenario, config: ExecutionConfig) -> Self { + async fn with_config(scenario: Scenario, config: SessionConfig) -> Self { let file = make_test_file(scenario).await; let parquet_path = file.path().to_string_lossy(); // now, setup a the file as a data source and run a query against it - let mut ctx = ExecutionContext::with_config(config); + let mut ctx = SessionContext::with_config(config); ctx.register_parquet("t", &parquet_path).await.unwrap(); let provider = ctx.deregister_table("t").unwrap().unwrap(); @@ -537,8 +537,8 @@ impl ContextWithParquet { .await .expect("creating physical plan"); - let runtime = self.ctx.state.lock().runtime_env.clone(); - let results = datafusion::physical_plan::collect(physical_plan.clone(), runtime) + let task_ctx = self.ctx.task_ctx(); + let results = datafusion::physical_plan::collect(physical_plan.clone(), task_ctx) .await .expect("Running"); diff --git a/datafusion/tests/path_partition.rs b/datafusion/tests/path_partition.rs index 178e318775c9f..64f6462c082b7 100644 --- a/datafusion/tests/path_partition.rs +++ b/datafusion/tests/path_partition.rs @@ -32,14 +32,14 @@ use datafusion::{ }, error::{DataFusionError, Result}, physical_plan::ColumnStatistics, - prelude::ExecutionContext, + prelude::SessionContext, test_util::{self, arrow_test_data, parquet_test_data}, }; use futures::{stream, StreamExt}; #[tokio::test] async fn csv_filter_with_file_col() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_partitioned_aggregate_csv( &mut ctx, @@ -75,7 +75,7 @@ async fn csv_filter_with_file_col() -> Result<()> { #[tokio::test] async fn csv_projection_on_partition() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_partitioned_aggregate_csv( &mut ctx, @@ -111,7 +111,7 @@ async fn csv_projection_on_partition() -> Result<()> { #[tokio::test] async fn csv_grouping_by_partition() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_partitioned_aggregate_csv( &mut ctx, @@ -145,7 +145,7 @@ async fn csv_grouping_by_partition() -> Result<()> { #[tokio::test] async fn parquet_multiple_partitions() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_partitioned_alltypes_parquet( &mut ctx, @@ -187,7 +187,7 @@ async fn parquet_multiple_partitions() -> Result<()> { #[tokio::test] async fn parquet_statistics() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_partitioned_alltypes_parquet( &mut ctx, @@ -246,7 +246,7 @@ async fn parquet_statistics() -> Result<()> { #[tokio::test] async fn parquet_overlapping_columns() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); // `id` is both a column of the file and a partitioning col register_partitioned_alltypes_parquet( @@ -272,7 +272,7 @@ async fn parquet_overlapping_columns() 
-> Result<()> { } fn register_partitioned_aggregate_csv( - ctx: &mut ExecutionContext, + ctx: &mut SessionContext, store_paths: &[&str], partition_cols: &[&str], table_path: &str, @@ -295,7 +295,7 @@ fn register_partitioned_aggregate_csv( } async fn register_partitioned_alltypes_parquet( - ctx: &mut ExecutionContext, + ctx: &mut SessionContext, store_paths: &[&str], partition_cols: &[&str], table_path: &str, diff --git a/datafusion/tests/provider_filter_pushdown.rs b/datafusion/tests/provider_filter_pushdown.rs index 203fb7ce56ff6..62d1cac51c090 100644 --- a/datafusion/tests/provider_filter_pushdown.rs +++ b/datafusion/tests/provider_filter_pushdown.rs @@ -21,8 +21,7 @@ use arrow::record_batch::RecordBatch; use async_trait::async_trait; use datafusion::datasource::datasource::{TableProvider, TableProviderFilterPushDown}; use datafusion::error::Result; -use datafusion::execution::context::ExecutionContext; -use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::context::{SessionContext, TaskContext}; use datafusion::logical_plan::Expr; use datafusion::physical_plan::common::SizedRecordBatchStream; use datafusion::physical_plan::expressions::PhysicalSortExpr; @@ -88,7 +87,7 @@ impl ExecutionPlan for CustomPlan { async fn execute( &self, partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { let metrics = ExecutionPlanMetricsSet::new(); let tracking_metrics = MemTrackingMetrics::new(&metrics, partition); @@ -174,7 +173,7 @@ async fn assert_provider_row_count(value: i64, expected_count: u64) -> Result<() one_batch: create_batch(1, 5)?, }; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let df = ctx .read_table(Arc::new(provider.clone()))? .filter(col("flag").eq(lit(value)))? diff --git a/datafusion/tests/sql/aggregates.rs b/datafusion/tests/sql/aggregates.rs index 187778c02fe95..8c5640c04344b 100644 --- a/datafusion/tests/sql/aggregates.rs +++ b/datafusion/tests/sql/aggregates.rs @@ -20,14 +20,14 @@ use datafusion::scalar::ScalarValue; #[tokio::test] async fn csv_query_avg_multi_batch() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT avg(c12) FROM aggregate_test_100"; let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().runtime_env.clone(); - let results = collect(plan, runtime).await.unwrap(); + let task_ctx = ctx.task_ctx(); + let results = collect(plan, task_ctx).await.unwrap(); let batch = &results[0]; let column = batch.column(0); let array = column.as_any().downcast_ref::().unwrap(); @@ -41,7 +41,7 @@ async fn csv_query_avg_multi_batch() -> Result<()> { #[tokio::test] async fn csv_query_avg() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT avg(c12) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; @@ -53,7 +53,7 @@ async fn csv_query_avg() -> Result<()> { #[tokio::test] async fn csv_query_covariance_1() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT covar_pop(c2, c12) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; @@ -65,7 +65,7 @@ async fn csv_query_covariance_1() -> Result<()> { #[tokio::test] async fn csv_query_covariance_2() 
-> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT covar(c2, c12) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; @@ -77,7 +77,7 @@ async fn csv_query_covariance_2() -> Result<()> { #[tokio::test] async fn csv_query_correlation() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT corr(c2, c12) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; @@ -89,7 +89,7 @@ async fn csv_query_correlation() -> Result<()> { #[tokio::test] async fn csv_query_variance_1() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT var_pop(c2) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; @@ -101,7 +101,7 @@ async fn csv_query_variance_1() -> Result<()> { #[tokio::test] async fn csv_query_variance_2() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT var_pop(c6) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; @@ -113,7 +113,7 @@ async fn csv_query_variance_2() -> Result<()> { #[tokio::test] async fn csv_query_variance_3() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT var_pop(c12) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; @@ -125,7 +125,7 @@ async fn csv_query_variance_3() -> Result<()> { #[tokio::test] async fn csv_query_variance_4() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT var(c2) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; @@ -137,7 +137,7 @@ async fn csv_query_variance_4() -> Result<()> { #[tokio::test] async fn csv_query_variance_5() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT var_samp(c2) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; @@ -149,7 +149,7 @@ async fn csv_query_variance_5() -> Result<()> { #[tokio::test] async fn csv_query_stddev_1() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT stddev_pop(c2) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; @@ -161,7 +161,7 @@ async fn csv_query_stddev_1() -> Result<()> { #[tokio::test] async fn csv_query_stddev_2() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT stddev_pop(c6) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; @@ -173,7 +173,7 @@ async fn csv_query_stddev_2() -> Result<()> { #[tokio::test] async fn csv_query_stddev_3() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT stddev_pop(c12) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; @@ -185,7 +185,7 @@ async fn csv_query_stddev_3() -> Result<()> { #[tokio::test] async fn csv_query_stddev_4() -> 
Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT stddev(c12) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; @@ -197,7 +197,7 @@ async fn csv_query_stddev_4() -> Result<()> { #[tokio::test] async fn csv_query_stddev_5() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT stddev_samp(c12) FROM aggregate_test_100"; let mut actual = execute(&mut ctx, sql).await; @@ -209,7 +209,7 @@ async fn csv_query_stddev_5() -> Result<()> { #[tokio::test] async fn csv_query_stddev_6() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "select stddev(sq.column1) from (values (1.1), (2.0), (3.0)) as sq"; let mut actual = execute(&mut ctx, sql).await; @@ -221,7 +221,7 @@ async fn csv_query_stddev_6() -> Result<()> { #[tokio::test] async fn csv_query_median_1() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT approx_median(c2) FROM aggregate_test_100"; let actual = execute(&mut ctx, sql).await; @@ -232,7 +232,7 @@ async fn csv_query_median_1() -> Result<()> { #[tokio::test] async fn csv_query_median_2() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT approx_median(c6) FROM aggregate_test_100"; let actual = execute(&mut ctx, sql).await; @@ -243,7 +243,7 @@ async fn csv_query_median_2() -> Result<()> { #[tokio::test] async fn csv_query_median_3() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT approx_median(c12) FROM aggregate_test_100"; let actual = execute(&mut ctx, sql).await; @@ -254,10 +254,10 @@ async fn csv_query_median_3() -> Result<()> { #[tokio::test] async fn csv_query_external_table_count() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "SELECT COUNT(c12) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------------+", "| COUNT(aggregate_test_100.c12) |", @@ -271,12 +271,12 @@ async fn csv_query_external_table_count() { #[tokio::test] async fn csv_query_external_table_sum() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); // cast smallint and int to bigint to avoid overflow during calculation register_aggregate_csv_by_sql(&mut ctx).await; let sql = "SELECT SUM(CAST(c7 AS BIGINT)), SUM(CAST(c8 AS BIGINT)) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------------------------+-------------------------------------------+", "| SUM(CAST(aggregate_test_100.c7 AS Int64)) | SUM(CAST(aggregate_test_100.c8 AS Int64)) |", @@ -289,10 +289,10 @@ async fn csv_query_external_table_sum() { #[tokio::test] async fn csv_query_count() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT count(c12) FROM 
aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------------+", "| COUNT(aggregate_test_100.c12) |", @@ -306,10 +306,10 @@ async fn csv_query_count() -> Result<()> { #[tokio::test] async fn csv_query_count_distinct() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT count(distinct c2) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------------------------+", "| COUNT(DISTINCT aggregate_test_100.c2) |", @@ -323,10 +323,10 @@ async fn csv_query_count_distinct() -> Result<()> { #[tokio::test] async fn csv_query_count_distinct_expr() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT count(distinct c2 % 2) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------------------------------------------+", "| COUNT(DISTINCT aggregate_test_100.c2 % Int64(2)) |", @@ -340,10 +340,10 @@ async fn csv_query_count_distinct_expr() -> Result<()> { #[tokio::test] async fn csv_query_count_star() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "SELECT COUNT(*) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", "| COUNT(UInt8(1)) |", @@ -356,10 +356,10 @@ async fn csv_query_count_star() { #[tokio::test] async fn csv_query_count_one() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "SELECT COUNT(1) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", "| COUNT(UInt8(1)) |", @@ -372,10 +372,10 @@ async fn csv_query_count_one() { #[tokio::test] async fn csv_query_approx_count() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT approx_distinct(c9) count_c9, approx_distinct(cast(c9 as varchar)) count_c9_str FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------+--------------+", "| count_c9 | count_c9_str |", @@ -412,7 +412,7 @@ async fn csv_query_approx_count() -> Result<()> { // float values. #[tokio::test] async fn csv_query_approx_percentile_cont() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; // Generate an assertion that the estimated $percentile value for $column is @@ -420,7 +420,7 @@ async fn csv_query_approx_percentile_cont() -> Result<()> { macro_rules! 
percentile_test { ($ctx:ident, column=$column:literal, percentile=$percentile:literal, actual=$actual:literal) => { let sql = format!("SELECT (ABS(1 - CAST(approx_percentile_cont({}, {}) AS DOUBLE) / {}) < 0.05) AS q FROM aggregate_test_100", $column, $percentile, $actual); - let actual = execute_to_batches(&mut ctx, &sql).await; + let actual = execute_to_batches(&ctx, &sql).await; // // "+------+", // "| q |", @@ -478,10 +478,10 @@ async fn csv_query_approx_percentile_cont() -> Result<()> { #[tokio::test] async fn csv_query_sum_crossjoin() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "SELECT a.c1, b.c1, SUM(a.c2) FROM aggregate_test_100 as a CROSS JOIN aggregate_test_100 as b GROUP BY a.c1, b.c1 ORDER BY a.c1, b.c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+----+-----------+", "| c1 | c1 | SUM(a.c2) |", @@ -518,9 +518,9 @@ async fn csv_query_sum_crossjoin() { #[tokio::test] async fn query_count_without_from() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT count(1 + 1)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------+", "| COUNT(Int64(1) + Int64(1)) |", @@ -534,11 +534,11 @@ async fn query_count_without_from() -> Result<()> { #[tokio::test] async fn csv_query_array_agg() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 2) test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------------------------------------------------------------------+", "| ARRAYAGG(test.c13) |", @@ -552,11 +552,11 @@ async fn csv_query_array_agg() -> Result<()> { #[tokio::test] async fn csv_query_array_agg_empty() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 LIMIT 0) test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------------+", "| ARRAYAGG(test.c13) |", @@ -570,11 +570,11 @@ async fn csv_query_array_agg_empty() -> Result<()> { #[tokio::test] async fn csv_query_array_agg_one() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT array_agg(c13) FROM (SELECT * FROM aggregate_test_100 ORDER BY c13 LIMIT 1) test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------------+", "| ARRAYAGG(test.c13) |", @@ -588,10 +588,10 @@ async fn csv_query_array_agg_one() -> Result<()> { #[tokio::test] async fn csv_query_array_agg_distinct() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT array_agg(distinct c2) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; // The 
results for this query should be something like the following: // +------------------------------------------+ @@ -638,7 +638,7 @@ async fn csv_query_array_agg_distinct() -> Result<()> { #[tokio::test] async fn aggregate_timestamps_sum() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("t", table_with_timestamps()).unwrap(); let results = plan_and_collect( @@ -655,11 +655,11 @@ async fn aggregate_timestamps_sum() -> Result<()> { #[tokio::test] async fn aggregate_timestamps_count() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("t", table_with_timestamps()).unwrap(); let results = execute_to_batches( - &mut ctx, + &ctx, "SELECT count(nanos), count(micros), count(millis), count(secs) FROM t", ) .await; @@ -678,11 +678,11 @@ async fn aggregate_timestamps_count() -> Result<()> { #[tokio::test] async fn aggregate_timestamps_min() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("t", table_with_timestamps()).unwrap(); let results = execute_to_batches( - &mut ctx, + &ctx, "SELECT min(nanos), min(micros), min(millis), min(secs) FROM t", ) .await; @@ -701,11 +701,11 @@ async fn aggregate_timestamps_min() -> Result<()> { #[tokio::test] async fn aggregate_timestamps_max() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("t", table_with_timestamps()).unwrap(); let results = execute_to_batches( - &mut ctx, + &ctx, "SELECT max(nanos), max(micros), max(millis), max(secs) FROM t", ) .await; @@ -724,7 +724,7 @@ async fn aggregate_timestamps_max() -> Result<()> { #[tokio::test] async fn aggregate_timestamps_avg() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("t", table_with_timestamps()).unwrap(); let results = plan_and_collect( diff --git a/datafusion/tests/sql/avro.rs b/datafusion/tests/sql/avro.rs index 82d91a0bd4812..5289efa981c75 100644 --- a/datafusion/tests/sql/avro.rs +++ b/datafusion/tests/sql/avro.rs @@ -17,7 +17,7 @@ use super::*; -async fn register_alltypes_avro(ctx: &mut ExecutionContext) { +async fn register_alltypes_avro(ctx: &mut SessionContext) { let testdata = datafusion::test_util::arrow_test_data(); ctx.register_avro( "alltypes_plain", @@ -30,12 +30,12 @@ async fn register_alltypes_avro(ctx: &mut ExecutionContext) { #[tokio::test] async fn avro_query() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_alltypes_avro(&mut ctx).await; // NOTE that string_col is actually a binary column and does not have the UTF8 logical type // so we need an explicit cast let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+-----------------------------------------+", "| id | CAST(alltypes_plain.string_col AS Utf8) |", @@ -71,7 +71,7 @@ async fn avro_query_multiple_files() { ) .unwrap(); - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_avro( "alltypes_plain", table_path.display().to_string().as_str(), @@ -82,7 +82,7 @@ async fn avro_query_multiple_files() { // NOTE that string_col is actually a binary column and does not have the UTF8 logical type // so we need an explicit cast let sql = "SELECT id, CAST(string_col AS varchar) FROM 
alltypes_plain"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+-----------------------------------------+", "| id | CAST(alltypes_plain.string_col AS Utf8) |", @@ -111,7 +111,7 @@ async fn avro_query_multiple_files() { #[tokio::test] async fn avro_single_nan_schema() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let testdata = datafusion::test_util::arrow_test_data(); ctx.register_avro( "single_nan", @@ -134,7 +134,7 @@ async fn avro_single_nan_schema() { #[tokio::test] async fn avro_explain() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_alltypes_avro(&mut ctx).await; let sql = "EXPLAIN SELECT count(*) from alltypes_plain"; diff --git a/datafusion/tests/sql/create_drop.rs b/datafusion/tests/sql/create_drop.rs index 45f2a36047c55..ad786610af4ac 100644 --- a/datafusion/tests/sql/create_drop.rs +++ b/datafusion/tests/sql/create_drop.rs @@ -23,14 +23,14 @@ use super::*; #[tokio::test] async fn create_table_as() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_simple_csv(&mut ctx).await?; let sql = "CREATE TABLE my_table AS SELECT * FROM aggregate_simple"; ctx.sql(sql).await.unwrap(); let sql_all = "SELECT * FROM my_table order by c1 LIMIT 1"; - let results_all = execute_to_batches(&mut ctx, sql_all).await; + let results_all = execute_to_batches(&ctx, sql_all).await; let expected = vec![ "+---------+----------------+------+", @@ -47,7 +47,7 @@ async fn create_table_as() -> Result<()> { #[tokio::test] async fn drop_table() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_simple_csv(&mut ctx).await?; let sql = "CREATE TABLE my_table AS SELECT * FROM aggregate_simple"; @@ -67,10 +67,10 @@ async fn drop_table() -> Result<()> { #[tokio::test] async fn csv_query_create_external_table() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "SELECT c1, c2, c3, c4, c5, c6, c7, c8, c9, 10, c11, c12, c13 FROM aggregate_test_100 LIMIT 1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+----+----+-------+------------+----------------------+----+-------+------------+-----------+-------------+--------------------+--------------------------------+", "| c1 | c2 | c3 | c4 | c5 | c6 | c7 | c8 | c9 | Int64(10) | c11 | c12 | c13 |", @@ -83,7 +83,7 @@ async fn csv_query_create_external_table() { #[tokio::test] async fn create_external_table_with_timestamps() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let data = "Jorge,2018-12-13T12:12:10.011Z\n\ Andrew,2018-11-13T17:11:10.011Z"; diff --git a/datafusion/tests/sql/errors.rs b/datafusion/tests/sql/errors.rs index 92b634dd5e966..ef4bb6ee34ea7 100644 --- a/datafusion/tests/sql/errors.rs +++ b/datafusion/tests/sql/errors.rs @@ -37,8 +37,8 @@ async fn test_cast_expressions_error() -> Result<()> { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().runtime_env.clone(); - let result = collect(plan, runtime).await; + let task_ctx = ctx.task_ctx(); + let result = collect(plan, task_ctx).await; match result { Ok(_) => panic!("expected error"), 
@@ -54,7 +54,7 @@ async fn test_cast_expressions_error() -> Result<()> { #[tokio::test] async fn test_aggregation_with_bad_arguments() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT COUNT(DISTINCT) FROM aggregate_test_100"; let logical_plan = ctx.create_logical_plan(sql); @@ -71,7 +71,7 @@ async fn test_aggregation_with_bad_arguments() -> Result<()> { #[tokio::test] async fn query_cte_incorrect() -> Result<()> { - let ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // self reference let sql = "WITH t AS (SELECT * FROM t) SELECT * from u"; @@ -105,7 +105,7 @@ async fn query_cte_incorrect() -> Result<()> { #[tokio::test] async fn test_select_wildcard_without_table() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let sql = "SELECT * "; let actual = ctx.sql(sql).await; match actual { @@ -122,7 +122,7 @@ async fn test_select_wildcard_without_table() -> Result<()> { #[tokio::test] async fn invalid_qualified_table_references() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; for table_ref in &[ diff --git a/datafusion/tests/sql/explain.rs b/datafusion/tests/sql/explain.rs index 00842b5eb8abf..b85228016e507 100644 --- a/datafusion/tests/sql/explain.rs +++ b/datafusion/tests/sql/explain.rs @@ -18,7 +18,7 @@ use arrow::datatypes::{DataType, Field, Schema}; use datafusion::{ logical_plan::{LogicalPlan, LogicalPlanBuilder, PlanType}, - prelude::ExecutionContext, + prelude::SessionContext, }; #[test] @@ -39,7 +39,7 @@ fn optimize_explain() { } // now optimize the plan and expect to see more plans - let optimized_plan = ExecutionContext::new().optimize(&plan).unwrap(); + let optimized_plan = SessionContext::new().optimize(&plan).unwrap(); if let LogicalPlan::Explain(e) = &optimized_plan { // should have more than one plan assert!( diff --git a/datafusion/tests/sql/explain_analyze.rs b/datafusion/tests/sql/explain_analyze.rs index 2051bdd1b80b7..84946261ad20b 100644 --- a/datafusion/tests/sql/explain_analyze.rs +++ b/datafusion/tests/sql/explain_analyze.rs @@ -21,8 +21,8 @@ use super::*; async fn explain_analyze_baseline_metrics() { // This test uses the execute function to run an actual plan under EXPLAIN ANALYZE // and then validate the presence of baseline metrics for supported operators - let config = ExecutionConfig::new().with_target_partitions(3); - let mut ctx = ExecutionContext::with_config(config); + let config = SessionConfig::new().with_target_partitions(3); + let mut ctx = SessionContext::with_config(config); register_aggregate_csv_by_sql(&mut ctx).await; // a query with as many operators as we have metrics for let sql = "EXPLAIN ANALYZE \ @@ -41,8 +41,8 @@ async fn explain_analyze_baseline_metrics() { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let physical_plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().runtime_env.clone(); - let results = collect(physical_plan.clone(), runtime).await.unwrap(); + let task_ctx = ctx.task_ctx(); + let results = collect(physical_plan.clone(), task_ctx).await.unwrap(); let formatted = arrow::util::pretty::pretty_format_batches(&results) .unwrap() .to_string(); @@ -168,7 +168,7 @@ async fn explain_analyze_baseline_metrics() { async fn csv_explain_plans() { // This test verify the look of each plan in its full cycle 
plan creation - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10"; @@ -329,8 +329,8 @@ async fn csv_explain_plans() { // // Execute plan let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let runtime = ctx.state.lock().runtime_env.clone(); - let results = collect(plan, runtime).await.expect(&msg); + let task_ctx = ctx.task_ctx(); + let results = collect(plan, task_ctx).await.expect(&msg); let actual = result_vec(&results); // flatten to a single string let actual = actual.into_iter().map(|r| r.join("\t")).collect::(); @@ -342,7 +342,7 @@ async fn csv_explain_plans() { #[tokio::test] async fn csv_explain_verbose() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 > 10"; let actual = execute(&mut ctx, sql).await; @@ -365,7 +365,7 @@ async fn csv_explain_verbose() { async fn csv_explain_verbose_plans() { // This test verify the look of each plan in its full cycle plan creation - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "EXPLAIN VERBOSE SELECT c1 FROM aggregate_test_100 where c2 > 10"; @@ -527,8 +527,8 @@ async fn csv_explain_verbose_plans() { // // Execute plan let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let runtime = ctx.state.lock().runtime_env.clone(); - let results = collect(plan, runtime).await.expect(&msg); + let task_ctx = ctx.task_ctx(); + let results = collect(plan, task_ctx).await.expect(&msg); let actual = result_vec(&results); // flatten to a single string let actual = actual.into_iter().map(|r| r.join("\t")).collect::(); @@ -545,7 +545,7 @@ async fn csv_explain_verbose_plans() { async fn explain_analyze_runs_optimizers() { // repro for https://github.com/apache/arrow-datafusion/issues/917 // where EXPLAIN ANALYZE was not correctly running optiimizer - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_alltypes_parquet(&mut ctx).await; // This happens as an optimization pass where count(*) can be @@ -553,7 +553,7 @@ async fn explain_analyze_runs_optimizers() { let expected = "EmptyExec: produce_one_row=true"; let sql = "EXPLAIN SELECT count(*) from alltypes_plain"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let actual = arrow::util::pretty::pretty_format_batches(&actual) .unwrap() .to_string(); @@ -561,7 +561,7 @@ async fn explain_analyze_runs_optimizers() { // EXPLAIN ANALYZE should work the same let sql = "EXPLAIN ANALYZE SELECT count(*) from alltypes_plain"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let actual = arrow::util::pretty::pretty_format_batches(&actual) .unwrap() .to_string(); @@ -570,7 +570,7 @@ async fn explain_analyze_runs_optimizers() { #[tokio::test] async fn tpch_explain_q10() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_tpch_csv(&mut ctx, "customer").await?; register_tpch_csv(&mut ctx, "orders").await?; @@ -633,8 +633,8 @@ order by #[tokio::test] async fn test_physical_plan_display_indent() { // Hard code target_partitions as it appears in the RepartitionExec output - let config = 
ExecutionConfig::new().with_target_partitions(3); - let mut ctx = ExecutionContext::with_config(config); + let config = SessionConfig::new().with_target_partitions(3); + let mut ctx = SessionContext::with_config(config); register_aggregate_csv(&mut ctx).await.unwrap(); let sql = "SELECT c1, MAX(c12), MIN(c12) as the_min \ FROM aggregate_test_100 \ @@ -679,8 +679,8 @@ async fn test_physical_plan_display_indent() { #[tokio::test] async fn test_physical_plan_display_indent_multi_children() { // Hard code target_partitions as it appears in the RepartitionExec output - let config = ExecutionConfig::new().with_target_partitions(3); - let mut ctx = ExecutionContext::with_config(config); + let config = SessionConfig::new().with_target_partitions(3); + let mut ctx = SessionContext::with_config(config); // ensure indenting works for nodes with multiple children register_aggregate_csv(&mut ctx).await.unwrap(); let sql = "SELECT c1 \ @@ -731,7 +731,7 @@ async fn test_physical_plan_display_indent_multi_children() { async fn csv_explain() { // This test uses the execute function that create full plan cycle: logical, optimized logical, and physical, // then execute the physical plan and return the final explain results - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "EXPLAIN SELECT c1 FROM aggregate_test_100 where c2 > 10"; let actual = execute(&mut ctx, sql).await; @@ -766,10 +766,10 @@ async fn csv_explain() { #[tokio::test] async fn csv_explain_analyze() { // This test uses the execute function to run an actual plan under EXPLAIN ANALYZE - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "EXPLAIN ANALYZE SELECT count(*), c1 FROM aggregate_test_100 group by c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let formatted = arrow::util::pretty::pretty_format_batches(&actual) .unwrap() .to_string(); @@ -787,11 +787,11 @@ async fn csv_explain_analyze() { #[tokio::test] async fn csv_explain_analyze_verbose() { // This test uses the execute function to run an actual plan under EXPLAIN VERBOSE ANALYZE - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "EXPLAIN ANALYZE VERBOSE SELECT count(*), c1 FROM aggregate_test_100 group by c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let formatted = arrow::util::pretty::pretty_format_batches(&actual) .unwrap() .to_string(); diff --git a/datafusion/tests/sql/expr.rs b/datafusion/tests/sql/expr.rs index fbbc44e1096b2..f70c8da7e4b70 100644 --- a/datafusion/tests/sql/expr.rs +++ b/datafusion/tests/sql/expr.rs @@ -19,13 +19,13 @@ use super::*; #[tokio::test] async fn case_when() -> Result<()> { - let mut ctx = create_case_context()?; + let ctx = create_case_context()?; let sql = "SELECT \ CASE WHEN c1 = 'a' THEN 1 \ WHEN c1 = 'b' THEN 2 \ END \ FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------------------------------------------------------------------------------+", "| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN #t1.c1 = Utf8(\"b\") THEN Int64(2) END |", @@ -42,13 +42,13 @@ async fn case_when() -> Result<()> { #[tokio::test] async fn case_when_else() -> Result<()> { - let mut ctx 
= create_case_context()?; + let ctx = create_case_context()?; let sql = "SELECT \ CASE WHEN c1 = 'a' THEN 1 \ WHEN c1 = 'b' THEN 2 \ ELSE 999 END \ FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------------------------------------------------------------------------------------------------------+", "| CASE WHEN #t1.c1 = Utf8(\"a\") THEN Int64(1) WHEN #t1.c1 = Utf8(\"b\") THEN Int64(2) ELSE Int64(999) END |", @@ -65,13 +65,13 @@ async fn case_when_else() -> Result<()> { #[tokio::test] async fn case_when_with_base_expr() -> Result<()> { - let mut ctx = create_case_context()?; + let ctx = create_case_context()?; let sql = "SELECT \ CASE c1 WHEN 'a' THEN 1 \ WHEN 'b' THEN 2 \ END \ FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------------------------------------------------------------+", "| CASE #t1.c1 WHEN Utf8(\"a\") THEN Int64(1) WHEN Utf8(\"b\") THEN Int64(2) END |", @@ -88,13 +88,13 @@ async fn case_when_with_base_expr() -> Result<()> { #[tokio::test] async fn case_when_else_with_base_expr() -> Result<()> { - let mut ctx = create_case_context()?; + let ctx = create_case_context()?; let sql = "SELECT \ CASE c1 WHEN 'a' THEN 1 \ WHEN 'b' THEN 2 \ ELSE 999 END \ FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------------------------------------------------------------------------+", "| CASE #t1.c1 WHEN Utf8(\"a\") THEN Int64(1) WHEN Utf8(\"b\") THEN Int64(2) ELSE Int64(999) END |", @@ -124,10 +124,10 @@ async fn query_not() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT NOT c1 FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------+", "| NOT test.c1 |", @@ -143,7 +143,7 @@ async fn query_not() -> Result<()> { #[tokio::test] async fn csv_query_sum_cast() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; // c8 = i32; c9 = i64 let sql = "SELECT c8 + c9 FROM aggregate_test_100"; @@ -166,10 +166,10 @@ async fn query_is_null() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT c1 IS NULL FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", "| test.c1 IS NULL |", @@ -198,10 +198,10 @@ async fn query_is_not_null() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT c1 IS NOT NULL FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------+", "| test.c1 IS NOT NULL |", @@ -219,10 +219,10 @@ async fn query_is_not_null() -> Result<()> { async fn query_without_from() -> Result<()> { // Test for SELECT 
without FROM. // Should evaluate expressions in project position. - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT 1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------+", "| Int64(1) |", @@ -233,7 +233,7 @@ async fn query_without_from() -> Result<()> { assert_batches_eq!(expected, &actual); let sql = "SELECT 1+2, 3/4, cos(0)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------+---------------------+---------------+", "| Int64(1) + Int64(2) | Int64(3) / Int64(4) | cos(Int64(0)) |", @@ -262,10 +262,10 @@ async fn query_scalar_minus_array() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT 4 - c1 FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------------------------+", "| Int64(4) Minus test.c1 |", @@ -669,9 +669,9 @@ async fn test_random_expression() -> Result<()> { #[tokio::test] async fn case_with_bool_type_result() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "select case when 'cpu' != 'cpu' then true else false end"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------------------------------------------------------------------+", "| CASE WHEN Utf8(\"cpu\") != Utf8(\"cpu\") THEN Boolean(true) ELSE Boolean(false) END |", @@ -685,7 +685,7 @@ async fn case_with_bool_type_result() -> Result<()> { #[tokio::test] async fn in_list_array() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "SELECT c1 IN ('a', 'c') AS utf8_in_true @@ -694,7 +694,7 @@ async fn in_list_array() -> Result<()> { ,c1 NOT IN ('a', 'c') AS utf8_not_in_false ,NULL IN ('a', 'c') AS utf8_in_null FROM aggregate_test_100 WHERE c12 < 0.05"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------+---------------+------------------+-------------------+--------------+", "| utf8_in_true | utf8_in_false | utf8_not_in_true | utf8_not_in_false | utf8_in_null |", @@ -810,11 +810,11 @@ async fn test_in_list_scalar() -> Result<()> { #[tokio::test] async fn csv_query_boolean_eq_neq() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_boolean(&mut ctx).await.unwrap(); // verify the plumbing is all hooked up for eq and neq let sql = "SELECT a, b, a = b as eq, b = true as eq_scalar, a != b as neq, a != true as neq_scalar FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+-------+-------+-----------+-------+------------+", @@ -836,11 +836,11 @@ async fn csv_query_boolean_eq_neq() { #[tokio::test] async fn csv_query_boolean_lt_lt_eq() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_boolean(&mut ctx).await.unwrap(); // verify the plumbing is all hooked up for < and <= let sql = "SELECT a, b, a < b as lt, b = true as lt_scalar, a 
<= b as lt_eq, a <= true as lt_eq_scalar FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+-------+-------+-----------+-------+--------------+", @@ -862,11 +862,11 @@ async fn csv_query_boolean_lt_lt_eq() { #[tokio::test] async fn csv_query_boolean_gt_gt_eq() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_boolean(&mut ctx).await.unwrap(); // verify the plumbing is all hooked up for > and >= let sql = "SELECT a, b, a > b as gt, b = true as gt_scalar, a >= b as gt_eq, a >= true as gt_eq_scalar FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+-------+-------+-----------+-------+--------------+", @@ -888,7 +888,7 @@ async fn csv_query_boolean_gt_gt_eq() { #[tokio::test] async fn csv_query_boolean_distinct_from() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_boolean(&mut ctx).await.unwrap(); // verify the plumbing is all hooked up for is distinct from and is not distinct from let sql = "SELECT a, b, \ @@ -897,7 +897,7 @@ async fn csv_query_boolean_distinct_from() { a is not distinct from b as ndf, \ a is not distinct from true as ndf_scalar \ FROM t1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+-------+-------+-----------+-------+------------+", @@ -919,7 +919,7 @@ async fn csv_query_boolean_distinct_from() { #[tokio::test] async fn csv_query_nullif_divide_by_0() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c8/nullif(c7, 0) FROM aggregate_test_100"; let actual = execute(&mut ctx, sql).await; @@ -941,10 +941,10 @@ async fn csv_query_nullif_divide_by_0() -> Result<()> { } #[tokio::test] async fn csv_count_star() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT COUNT(*), COUNT(1) AS c, COUNT(c1) FROM aggregate_test_100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+-----+------------------------------+", "| COUNT(UInt8(1)) | c | COUNT(aggregate_test_100.c1) |", diff --git a/datafusion/tests/sql/functions.rs b/datafusion/tests/sql/functions.rs index cf2475792a4e6..0e37b3c5cabcd 100644 --- a/datafusion/tests/sql/functions.rs +++ b/datafusion/tests/sql/functions.rs @@ -37,10 +37,10 @@ async fn sqrt_f32_vs_f64() -> Result<()> { #[tokio::test] async fn csv_query_cast() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT CAST(c12 AS float) FROM aggregate_test_100 WHERE c12 > 0.376 AND c12 < 0.4"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------------------------------+", @@ -57,11 +57,11 @@ async fn csv_query_cast() -> Result<()> { #[tokio::test] async fn csv_query_cast_literal() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c12, CAST(1 AS float) FROM aggregate_test_100 WHERE c12 > CAST(0 AS float) LIMIT 
2"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------------+---------------------------+", @@ -93,10 +93,10 @@ async fn query_concat() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT concat(c1, '-hi-', cast(c2 as varchar)) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------------------------------+", "| concat(test.c1,Utf8(\"-hi-\"),CAST(test.c2 AS Utf8)) |", @@ -129,7 +129,7 @@ async fn query_array() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT array(c1, cast(c2 as varchar)) FROM test"; let actual = execute(&mut ctx, sql).await; @@ -160,10 +160,10 @@ async fn query_count_distinct() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT COUNT(DISTINCT c1) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------+", "| COUNT(DISTINCT test.c1) |", diff --git a/datafusion/tests/sql/group_by.rs b/datafusion/tests/sql/group_by.rs index 38a0c2e442045..5430ca9d09f0e 100644 --- a/datafusion/tests/sql/group_by.rs +++ b/datafusion/tests/sql/group_by.rs @@ -19,10 +19,10 @@ use super::*; #[tokio::test] async fn csv_query_group_by_int_min_max() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c2, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c2"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+-----------------------------+-----------------------------+", "| c2 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) |", @@ -40,12 +40,12 @@ async fn csv_query_group_by_int_min_max() -> Result<()> { #[tokio::test] async fn csv_query_group_by_float32() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_simple_csv(&mut ctx).await?; let sql = "SELECT COUNT(*) as cnt, c1 FROM aggregate_simple GROUP BY c1 ORDER BY cnt DESC"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----+---------+", @@ -65,12 +65,12 @@ async fn csv_query_group_by_float32() -> Result<()> { #[tokio::test] async fn csv_query_group_by_float64() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_simple_csv(&mut ctx).await?; let sql = "SELECT COUNT(*) as cnt, c2 FROM aggregate_simple GROUP BY c2 ORDER BY cnt DESC"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----+----------------+", @@ -90,12 +90,12 @@ async fn csv_query_group_by_float64() -> Result<()> { #[tokio::test] async fn csv_query_group_by_boolean() -> Result<()> { - let 
mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_simple_csv(&mut ctx).await?; let sql = "SELECT COUNT(*) as cnt, c3 FROM aggregate_simple GROUP BY c3 ORDER BY cnt DESC"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----+-------+", @@ -112,10 +112,10 @@ async fn csv_query_group_by_boolean() -> Result<()> { #[tokio::test] async fn csv_query_group_by_two_columns() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c1, c2, MIN(c3) FROM aggregate_test_100 GROUP BY c1, c2"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+----+----------------------------+", "| c1 | c2 | MIN(aggregate_test_100.c3) |", @@ -153,10 +153,10 @@ async fn csv_query_group_by_two_columns() -> Result<()> { #[tokio::test] async fn csv_query_group_by_and_having() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c1, MIN(c3) AS m FROM aggregate_test_100 GROUP BY c1 HAVING m < -100 AND MAX(c3) > 70"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+------+", "| c1 | m |", @@ -171,14 +171,14 @@ async fn csv_query_group_by_and_having() -> Result<()> { #[tokio::test] async fn csv_query_group_by_and_having_and_where() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c1, MIN(c3) AS m FROM aggregate_test_100 WHERE c1 IN ('a', 'b') GROUP BY c1 HAVING m < -100 AND MAX(c3) > 70"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+------+", "| c1 | m |", @@ -192,10 +192,10 @@ async fn csv_query_group_by_and_having_and_where() -> Result<()> { #[tokio::test] async fn csv_query_having_without_group_by() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c1, c2, c3 FROM aggregate_test_100 HAVING c2 >= 4 AND c3 > 90"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+----+-----+", "| c1 | c2 | c3 |", @@ -213,10 +213,10 @@ async fn csv_query_having_without_group_by() -> Result<()> { #[tokio::test] async fn csv_query_group_by_avg() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c1, avg(c12) FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+-----------------------------+", "| c1 | AVG(aggregate_test_100.c12) |", @@ -234,10 +234,10 @@ async fn csv_query_group_by_avg() -> Result<()> { #[tokio::test] async fn csv_query_group_by_int_count() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c1, count(c12) FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, 
sql).await; let expected = vec![ "+----+-------------------------------+", "| c1 | COUNT(aggregate_test_100.c12) |", @@ -255,10 +255,10 @@ async fn csv_query_group_by_int_count() -> Result<()> { #[tokio::test] async fn csv_query_group_with_aliased_aggregate() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c1, count(c12) AS count FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+-------+", "| c1 | count |", @@ -276,10 +276,10 @@ async fn csv_query_group_with_aliased_aggregate() -> Result<()> { #[tokio::test] async fn csv_query_group_by_string_min_max() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c1, MIN(c12), MAX(c12) FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+-----------------------------+-----------------------------+", "| c1 | MIN(aggregate_test_100.c12) | MAX(aggregate_test_100.c12) |", @@ -312,11 +312,11 @@ async fn query_group_on_null() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT COUNT(*), c1 FROM test GROUP BY c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; // Note that the results also // include a row for NULL (c1=NULL, count = 1) @@ -371,11 +371,11 @@ async fn query_group_on_null_multi_col() -> Result<()> { let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c1, c2"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; // Note that the results also include values for null // include a row for NULL (c1=NULL, count = 1) @@ -393,14 +393,14 @@ async fn query_group_on_null_multi_col() -> Result<()> { // Also run query with group columns reversed (results should be the same) let sql = "SELECT COUNT(*), c1, c2 FROM test GROUP BY c2, c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_sorted_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn csv_group_by_date() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let schema = Arc::new(Schema::new(vec![ Field::new("date", DataType::Date32, false), Field::new("cnt", DataType::Int32, false), @@ -430,7 +430,7 @@ async fn csv_group_by_date() -> Result<()> { ctx.register_table("dates", Arc::new(table))?; let sql = "SELECT SUM(cnt) FROM dates GROUP BY date"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------+", "| SUM(dates.cnt) |", diff --git a/datafusion/tests/sql/information_schema.rs b/datafusion/tests/sql/information_schema.rs index d93f0d7328d39..260a1a60bede3 100644 --- a/datafusion/tests/sql/information_schema.rs +++ b/datafusion/tests/sql/information_schema.rs @@ -29,7 +29,7 @@ use 
super::*; #[tokio::test] async fn information_schema_tables_not_exist_by_default() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let err = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") .await @@ -42,9 +42,8 @@ async fn information_schema_tables_not_exist_by_default() { #[tokio::test] async fn information_schema_tables_no_tables() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); + let mut ctx = + SessionContext::with_config(SessionConfig::new().with_information_schema(true)); let result = plan_and_collect(&mut ctx, "SELECT * from information_schema.tables") .await @@ -63,9 +62,8 @@ async fn information_schema_tables_no_tables() { #[tokio::test] async fn information_schema_tables_tables_default_catalog() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); + let mut ctx = + SessionContext::with_config(SessionConfig::new().with_information_schema(true)); // Now, register an empty table ctx.register_table("t", table_with_sequence(1, 1).unwrap()) @@ -109,9 +107,8 @@ async fn information_schema_tables_tables_default_catalog() { #[tokio::test] async fn information_schema_tables_tables_with_multiple_catalogs() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); + let mut ctx = + SessionContext::with_config(SessionConfig::new().with_information_schema(true)); let catalog = MemoryCatalogProvider::new(); let schema = MemorySchemaProvider::new(); schema @@ -181,9 +178,8 @@ async fn information_schema_tables_table_types() { } } - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); + let mut ctx = + SessionContext::with_config(SessionConfig::new().with_information_schema(true)); ctx.register_table("physical", Arc::new(TestTable(TableType::Base))) .unwrap(); @@ -212,7 +208,7 @@ async fn information_schema_tables_table_types() { #[tokio::test] async fn information_schema_show_tables_no_information_schema() { - let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); + let mut ctx = SessionContext::with_config(SessionConfig::new()); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -225,9 +221,8 @@ async fn information_schema_show_tables_no_information_schema() { #[tokio::test] async fn information_schema_show_tables() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); + let mut ctx = + SessionContext::with_config(SessionConfig::new().with_information_schema(true)); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -253,7 +248,7 @@ async fn information_schema_show_tables() { #[tokio::test] async fn information_schema_show_columns_no_information_schema() { - let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); + let mut ctx = SessionContext::with_config(SessionConfig::new()); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -267,7 +262,7 @@ async fn information_schema_show_columns_no_information_schema() { #[tokio::test] async fn information_schema_show_columns_like_where() { - let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); + let mut ctx = SessionContext::with_config(SessionConfig::new()); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -288,9 +283,8 @@ async fn information_schema_show_columns_like_where() { 
#[tokio::test] async fn information_schema_show_columns() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); + let mut ctx = + SessionContext::with_config(SessionConfig::new().with_information_schema(true)); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -326,9 +320,8 @@ async fn information_schema_show_columns() { // test errors with WHERE and LIKE #[tokio::test] async fn information_schema_show_columns_full_extended() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); + let mut ctx = + SessionContext::with_config(SessionConfig::new().with_information_schema(true)); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -353,9 +346,8 @@ async fn information_schema_show_table_table_names() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); + let mut ctx = + SessionContext::with_config(SessionConfig::new().with_information_schema(true)); ctx.register_table("t", table_with_sequence(1, 1).unwrap()) .unwrap(); @@ -397,7 +389,7 @@ #[tokio::test] async fn show_unsupported() { - let mut ctx = ExecutionContext::with_config(ExecutionConfig::new()); + let mut ctx = SessionContext::with_config(SessionConfig::new()); let err = plan_and_collect(&mut ctx, "SHOW SOMETHING_UNKNOWN") .await .unwrap_err(); @@ -408,7 +400,7 @@ #[tokio::test] async fn information_schema_columns_not_exist_by_default() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let err = plan_and_collect(&mut ctx, "SELECT * from information_schema.columns") .await @@ -456,9 +448,8 @@ fn table_with_many_types() -> Arc<dyn TableProvider> { #[tokio::test] async fn information_schema_columns() { - let mut ctx = ExecutionContext::with_config( - ExecutionConfig::new().with_information_schema(true), - ); + let mut ctx = + SessionContext::with_config(SessionConfig::new().with_information_schema(true)); let catalog = MemoryCatalogProvider::new(); let schema = MemorySchemaProvider::new(); @@ -495,7 +486,7 @@ /// Execute SQL and return results async fn plan_and_collect( - ctx: &mut ExecutionContext, + ctx: &mut SessionContext, sql: &str, ) -> Result<Vec<RecordBatch>> { ctx.sql(sql).await?.collect().await diff --git a/datafusion/tests/sql/intersection.rs b/datafusion/tests/sql/intersection.rs index d28dd8079fa99..eec22eecdf55d 100644 --- a/datafusion/tests/sql/intersection.rs +++ b/datafusion/tests/sql/intersection.rs @@ -23,8 +23,8 @@ async fn intersect_with_null_not_equal() { INTERSECT SELECT * FROM (SELECT null AS id1, 2 AS id2) t2"; let expected = vec!["++", "++"]; - let mut ctx = create_join_context_qualified().unwrap(); - let actual = execute_to_batches(&mut ctx, sql).await; + let ctx = create_join_context_qualified().unwrap(); + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } @@ -41,19 +41,19 @@ async fn intersect_with_null_equal() { "+-----+-----+", ]; - let mut ctx = create_join_context_qualified().unwrap(); - let actual = execute_to_batches(&mut ctx, sql).await; + let ctx = create_join_context_qualified().unwrap(); + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } #[tokio::test] async fn test_intersect_all() -> Result<()> { - let mut ctx = 
ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_alltypes_parquet(&mut ctx).await; // execute the query let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 INTERSECT ALL SELECT int_col, double_col FROM alltypes_plain LIMIT 4"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+------------+", "| int_col | double_col |", @@ -70,11 +70,11 @@ async fn test_intersect_all() -> Result<()> { #[tokio::test] async fn test_intersect_distinct() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_alltypes_parquet(&mut ctx).await; // execute the query let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 INTERSECT SELECT int_col, double_col FROM alltypes_plain"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+------------+", "| int_col | double_col |", diff --git a/datafusion/tests/sql/joins.rs b/datafusion/tests/sql/joins.rs index 04436ed460b16..36ad05755591d 100644 --- a/datafusion/tests/sql/joins.rs +++ b/datafusion/tests/sql/joins.rs @@ -20,7 +20,7 @@ use datafusion::from_slice::FromSlice; #[tokio::test] async fn equijoin() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let equivalent_sql = [ "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id ORDER BY t1_id", "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t2_id = t1_id ORDER BY t1_id", @@ -35,11 +35,11 @@ async fn equijoin() -> Result<()> { "+-------+---------+---------+", ]; for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } - let mut ctx = create_join_context_qualified()?; + let ctx = create_join_context_qualified()?; let equivalent_sql = [ "SELECT t1.a, t2.b FROM t1 INNER JOIN t2 ON t1.a = t2.a ORDER BY t1.a", "SELECT t1.a, t2.b FROM t1 INNER JOIN t2 ON t2.a = t1.a ORDER BY t1.a", @@ -54,7 +54,7 @@ async fn equijoin() -> Result<()> { "+---+-----+", ]; for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } Ok(()) @@ -62,7 +62,7 @@ async fn equijoin() -> Result<()> { #[tokio::test] async fn equijoin_multiple_condition_ordering() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let equivalent_sql = [ "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t1_name <> t2_name ORDER BY t1_id", "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 ON t1_id = t2_id AND t2_name <> t1_name ORDER BY t1_id", @@ -79,7 +79,7 @@ async fn equijoin_multiple_condition_ordering() -> Result<()> { "+-------+---------+---------+", ]; for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } Ok(()) @@ -87,10 +87,10 @@ async fn equijoin_multiple_condition_ordering() -> Result<()> { #[tokio::test] async fn equijoin_and_other_condition() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let sql = "SELECT t1_id, t1_name, t2_name FROM t1 JOIN t2 
ON t1_id = t2_id AND t2_name >= 'y' ORDER BY t1_id"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+---------+---------+", "| t1_id | t1_name | t2_name |", @@ -105,12 +105,12 @@ async fn equijoin_and_other_condition() -> Result<()> { #[tokio::test] async fn equijoin_left_and_condition_from_right() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let sql = "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id AND t2_name >= 'y' ORDER BY t1_id"; let res = ctx.create_logical_plan(sql); assert!(res.is_ok()); - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+---------+---------+", "| t1_id | t1_name | t2_name |", @@ -128,12 +128,12 @@ async fn equijoin_left_and_condition_from_right() -> Result<()> { #[tokio::test] async fn equijoin_right_and_condition_from_left() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let sql = "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id AND t1_id >= 22 ORDER BY t2_name"; let res = ctx.create_logical_plan(sql); assert!(res.is_ok()); - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+---------+---------+", "| t1_id | t1_name | t2_name |", @@ -163,7 +163,7 @@ async fn equijoin_and_unsupported_condition() -> Result<()> { #[tokio::test] async fn left_join() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let equivalent_sql = [ "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id ORDER BY t1_id", "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t2_id = t1_id ORDER BY t1_id", @@ -179,7 +179,7 @@ async fn left_join() -> Result<()> { "+-------+---------+---------+", ]; for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } Ok(()) @@ -188,7 +188,7 @@ async fn left_join() -> Result<()> { #[tokio::test] async fn left_join_unbalanced() -> Result<()> { // the t1_id is larger than t2_id so the hash_build_probe_order optimizer should kick in - let mut ctx = create_join_context_unbalanced("t1_id", "t2_id")?; + let ctx = create_join_context_unbalanced("t1_id", "t2_id")?; let equivalent_sql = [ "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id ORDER BY t1_id", "SELECT t1_id, t1_name, t2_name FROM t1 LEFT JOIN t2 ON t2_id = t1_id ORDER BY t1_id", @@ -205,7 +205,7 @@ async fn left_join_unbalanced() -> Result<()> { "+-------+---------+---------+", ]; for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } Ok(()) @@ -216,7 +216,7 @@ async fn left_join_null_filter() -> Result<()> { // Since t2 is the non-preserved side of the join, we cannot push down a NULL filter. // Note that this is only true because IS NULL does not remove nulls. For filters that // remove nulls, we can rewrite the join as an inner join and then push down the filter. 
- let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_name IS NULL ORDER BY t1_id"; let expected = vec![ "+-------+-------+---------+", @@ -229,7 +229,7 @@ async fn left_join_null_filter() -> Result<()> { "+-------+-------+---------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } @@ -237,7 +237,7 @@ async fn left_join_null_filter() -> Result<()> { #[tokio::test] async fn left_join_null_filter_on_join_column() -> Result<()> { // Again, since t2 is the non-preserved side of the join, we cannot push down a NULL filter. - let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_id IS NULL ORDER BY t1_id"; let expected = vec![ "+-------+-------+---------+", @@ -249,14 +249,14 @@ async fn left_join_null_filter_on_join_column() -> Result<()> { "+-------+-------+---------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn left_join_not_null_filter() -> Result<()> { - let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_name IS NOT NULL ORDER BY t1_id"; let expected = vec![ "+-------+-------+---------+", @@ -268,14 +268,14 @@ async fn left_join_not_null_filter() -> Result<()> { "+-------+-------+---------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn left_join_not_null_filter_on_join_column() -> Result<()> { - let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t2_id, t2_name FROM t1 LEFT JOIN t2 ON t1_id = t2_id WHERE t2_id IS NOT NULL ORDER BY t1_id"; let expected = vec![ "+-------+-------+---------+", @@ -288,14 +288,14 @@ async fn left_join_not_null_filter_on_join_column() -> Result<()> { "+-------+-------+---------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn right_join_null_filter() -> Result<()> { - let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_name IS NULL ORDER BY t2_id"; let expected = vec![ "+-------+---------+-------+", @@ -306,14 +306,14 @@ async fn right_join_null_filter() -> Result<()> { "+-------+---------+-------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn right_join_null_filter_on_join_column() -> Result<()> { - let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_id IS NULL ORDER BY t2_id"; let expected = vec![ "+-------+---------+-------+", @@ -323,14 +323,14 @@ async fn 
right_join_null_filter_on_join_column() -> Result<()> { "+-------+---------+-------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn right_join_not_null_filter() -> Result<()> { - let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_name IS NOT NULL ORDER BY t2_id"; let expected = vec![ "+-------+---------+-------+", @@ -342,14 +342,14 @@ async fn right_join_not_null_filter() -> Result<()> { "+-------+---------+-------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn right_join_not_null_filter_on_join_column() -> Result<()> { - let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t1_name, t2_id FROM t1 RIGHT JOIN t2 ON t1_id = t2_id WHERE t1_id IS NOT NULL ORDER BY t2_id"; let expected = vec![ "+-------+---------+-------+", @@ -362,14 +362,14 @@ async fn right_join_not_null_filter_on_join_column() -> Result<()> { "+-------+---------+-------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn full_join_null_filter() -> Result<()> { - let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t1_name, t2_id FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id WHERE t1_name IS NULL ORDER BY t1_id"; let expected = vec![ "+-------+---------+-------+", @@ -381,14 +381,14 @@ async fn full_join_null_filter() -> Result<()> { "+-------+---------+-------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn full_join_not_null_filter() -> Result<()> { - let mut ctx = create_join_context_with_nulls()?; + let ctx = create_join_context_with_nulls()?; let sql = "SELECT t1_id, t1_name, t2_id FROM t1 FULL OUTER JOIN t2 ON t1_id = t2_id WHERE t1_name IS NOT NULL ORDER BY t1_id"; let expected = vec![ "+-------+---------+-------+", @@ -402,14 +402,14 @@ async fn full_join_not_null_filter() -> Result<()> { "+-------+---------+-------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn right_join() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let equivalent_sql = [ "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t1_id = t2_id ORDER BY t1_id", "SELECT t1_id, t1_name, t2_name FROM t1 RIGHT JOIN t2 ON t2_id = t1_id ORDER BY t1_id" @@ -425,7 +425,7 @@ async fn right_join() -> Result<()> { "+-------+---------+---------+", ]; for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } Ok(()) @@ -433,7 +433,7 @@ async fn right_join() -> Result<()> { #[tokio::test] async fn full_join() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", 
"t2_id")?; let equivalent_sql = [ "SELECT t1_id, t1_name, t2_name FROM t1 FULL JOIN t2 ON t1_id = t2_id ORDER BY t1_id", "SELECT t1_id, t1_name, t2_name FROM t1 FULL JOIN t2 ON t2_id = t1_id ORDER BY t1_id", @@ -450,7 +450,7 @@ async fn full_join() -> Result<()> { "+-------+---------+---------+", ]; for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } @@ -459,7 +459,7 @@ async fn full_join() -> Result<()> { "SELECT t1_id, t1_name, t2_name FROM t1 FULL OUTER JOIN t2 ON t2_id = t1_id ORDER BY t1_id", ]; for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } @@ -468,9 +468,9 @@ async fn full_join() -> Result<()> { #[tokio::test] async fn left_join_using() -> Result<()> { - let mut ctx = create_join_context("id", "id")?; + let ctx = create_join_context("id", "id")?; let sql = "SELECT id, t1_name, t2_name FROM t1 LEFT JOIN t2 USING (id) ORDER BY id"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+---------+---------+", "| id | t1_name | t2_name |", @@ -487,7 +487,7 @@ async fn left_join_using() -> Result<()> { #[tokio::test] async fn equijoin_implicit_syntax() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let equivalent_sql = [ "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t1_id = t2_id ORDER BY t1_id", "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t2_id = t1_id ORDER BY t1_id", @@ -502,7 +502,7 @@ async fn equijoin_implicit_syntax() -> Result<()> { "+-------+---------+---------+", ]; for sql in equivalent_sql.iter() { - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } Ok(()) @@ -510,14 +510,14 @@ async fn equijoin_implicit_syntax() -> Result<()> { #[tokio::test] async fn equijoin_implicit_syntax_with_filter() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let sql = "SELECT t1_id, t1_name, t2_name \ FROM t1, t2 \ WHERE t1_id > 0 \ AND t1_id = t2_id \ AND t2_id < 99 \ ORDER BY t1_id"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+---------+---------+", "| t1_id | t1_name | t2_name |", @@ -533,10 +533,10 @@ async fn equijoin_implicit_syntax_with_filter() -> Result<()> { #[tokio::test] async fn equijoin_implicit_syntax_reversed() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let sql = "SELECT t1_id, t1_name, t2_name FROM t1, t2 WHERE t2_id = t1_id ORDER BY t1_id"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+---------+---------+", "| t1_id | t1_name | t2_name |", @@ -569,7 +569,7 @@ async fn cross_join() { let actual = execute(&mut ctx, sql).await; assert_eq!(4 * 4, actual.len()); - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+---------+---------+", "| t1_id | t1_name | t2_name |", @@ -611,12 +611,12 @@ async fn cross_join() { #[tokio::test] async 
fn cross_join_unbalanced() { // the t1_id is larger than t2_id so the hash_build_probe_order optimizer should kick in - let mut ctx = create_join_context_unbalanced("t1_id", "t2_id").unwrap(); + let ctx = create_join_context_unbalanced("t1_id", "t2_id").unwrap(); // the order of the values is not determinisitic, so we need to sort to check the values let sql = "SELECT t1_id, t1_name, t2_name FROM t1 CROSS JOIN t2 ORDER BY t1_id, t1_name, t2_name"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+---------+---------+", "| t1_id | t1_name | t2_name |", @@ -648,7 +648,7 @@ async fn cross_join_unbalanced() { #[tokio::test] async fn test_join_timestamp() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); // register time table let timestamp_schema = Arc::new(Schema::new(vec![Field::new( @@ -673,7 +673,7 @@ async fn test_join_timestamp() -> Result<()> { JOIN (SELECT * FROM timestamp) as b \ ON a.time = b.time \ ORDER BY a.time"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------------+-------------------------------+", @@ -691,7 +691,7 @@ async fn test_join_timestamp() -> Result<()> { #[tokio::test] async fn test_join_float32() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); // register population table let population_schema = Arc::new(Schema::new(vec![ @@ -714,7 +714,7 @@ async fn test_join_float32() -> Result<()> { JOIN (SELECT * FROM population) as b \ ON a.population = b.population \ ORDER BY a.population"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------+------------+------+------------+", @@ -732,7 +732,7 @@ async fn test_join_float32() -> Result<()> { #[tokio::test] async fn test_join_float64() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); // register population table let population_schema = Arc::new(Schema::new(vec![ @@ -755,7 +755,7 @@ async fn test_join_float64() -> Result<()> { JOIN (SELECT * FROM population) as b \ ON a.population = b.population \ ORDER BY a.population"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------+------------+------+------------+", @@ -799,8 +799,8 @@ async fn inner_join_qualified_names() -> Result<()> { ]; for sql in equivalent_sql.iter() { - let mut ctx = create_join_context_qualified()?; - let actual = execute_to_batches(&mut ctx, sql).await; + let ctx = create_join_context_qualified()?; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } Ok(()) @@ -813,8 +813,8 @@ async fn inner_join_nulls() { let expected = vec!["++", "++"]; - let mut ctx = create_join_context_qualified().unwrap(); - let actual = execute_to_batches(&mut ctx, sql).await; + let ctx = create_join_context_qualified().unwrap(); + let actual = execute_to_batches(&ctx, sql).await; // left and right shouldn't match anything assert_batches_eq!(expected, &actual); @@ -857,14 +857,14 @@ async fn join_tables_with_duplicated_column_name_not_in_on_constraint() -> Resul .unwrap(); let cities = MemTable::try_new(batch.schema(), vec![vec![batch]])?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("countries", 
Arc::new(countries))?; ctx.register_table("cities", Arc::new(cities))?; // city.id is not in the on constraint, but the output result will contain both city.id and // country.id let sql = "SELECT t1.id, t2.id, t1.city, t2.country FROM cities AS t1 JOIN countries AS t2 ON t1.country_id = t2.id ORDER BY t1.id"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+----+-----------+---------+", "| id | id | city | country |", @@ -885,7 +885,7 @@ async fn join_tables_with_duplicated_column_name_not_in_on_constraint() -> Resul #[tokio::test] async fn join_timestamp() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("t", table_with_timestamps()).unwrap(); let expected = vec![ @@ -899,7 +899,7 @@ async fn join_timestamp() -> Result<()> { ]; let results = execute_to_batches( - &mut ctx, + &ctx, "SELECT * FROM t as t1 \ JOIN (SELECT * FROM t) as t2 \ ON t1.nanos = t2.nanos", @@ -909,7 +909,7 @@ async fn join_timestamp() -> Result<()> { assert_batches_sorted_eq!(expected, &results); let results = execute_to_batches( - &mut ctx, + &ctx, "SELECT * FROM t as t1 \ JOIN (SELECT * FROM t) as t2 \ ON t1.micros = t2.micros", @@ -919,7 +919,7 @@ async fn join_timestamp() -> Result<()> { assert_batches_sorted_eq!(expected, &results); let results = execute_to_batches( - &mut ctx, + &ctx, "SELECT * FROM t as t1 \ JOIN (SELECT * FROM t) as t2 \ ON t1.millis = t2.millis", diff --git a/datafusion/tests/sql/limit.rs b/datafusion/tests/sql/limit.rs index fd68e330bee18..3e7c1d6049d4f 100644 --- a/datafusion/tests/sql/limit.rs +++ b/datafusion/tests/sql/limit.rs @@ -19,10 +19,10 @@ use super::*; #[tokio::test] async fn csv_query_limit() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c1 FROM aggregate_test_100 LIMIT 2"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["+----+", "| c1 |", "+----+", "| c |", "| d |", "+----+"]; assert_batches_eq!(expected, &actual); Ok(()) @@ -30,10 +30,10 @@ async fn csv_query_limit() -> Result<()> { #[tokio::test] async fn csv_query_limit_bigger_than_nbr_of_rows() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 200"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; // println!("{}", pretty_format_batches(&a).unwrap()); let expected = vec![ "+----+", "| c2 |", "+----+", "| 2 |", "| 5 |", "| 1 |", "| 1 |", "| 5 |", @@ -56,10 +56,10 @@ async fn csv_query_limit_bigger_than_nbr_of_rows() -> Result<()> { #[tokio::test] async fn csv_query_limit_with_same_nbr_of_rows() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c2 FROM aggregate_test_100 LIMIT 100"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+", "| c2 |", "+----+", "| 2 |", "| 5 |", "| 1 |", "| 1 |", "| 5 |", "| 4 |", "| 3 |", "| 3 |", "| 1 |", "| 4 |", "| 1 |", "| 4 |", "| 3 |", @@ -81,10 +81,10 @@ async fn csv_query_limit_with_same_nbr_of_rows() -> Result<()> { #[tokio::test] async fn csv_query_limit_zero() -> 
Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c1 FROM aggregate_test_100 LIMIT 0"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["++", "++"]; assert_batches_eq!(expected, &actual); Ok(()) diff --git a/datafusion/tests/sql/mod.rs b/datafusion/tests/sql/mod.rs index a548d619d6357..cea85baace40b 100644 --- a/datafusion/tests/sql/mod.rs +++ b/datafusion/tests/sql/mod.rs @@ -45,7 +45,7 @@ use datafusion::{ error::{DataFusionError, Result}, physical_plan::ColumnarValue, }; -use datafusion::{execution::context::ExecutionContext, physical_plan::displayable}; +use datafusion::{execution::context::SessionContext, physical_plan::displayable}; /// A macro to assert that some particular line contains two substrings /// @@ -66,7 +66,7 @@ macro_rules! assert_metrics { macro_rules! test_expression { ($SQL:expr, $EXPECTED:expr) => { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let sql = format!("SELECT {}", $SQL); let actual = execute(&mut ctx, sql.as_str()).await; assert_eq!(actual[0][0], $EXPECTED); @@ -120,8 +120,8 @@ where } #[allow(clippy::unnecessary_wraps)] -fn create_ctx() -> Result<ExecutionContext> { - let mut ctx = ExecutionContext::new(); +fn create_ctx() -> Result<SessionContext> { + let mut ctx = SessionContext::new(); // register a custom UDF ctx.register_udf(create_udf( @@ -150,8 +150,8 @@ fn custom_sqrt(args: &[ColumnarValue]) -> Result<ColumnarValue> { } } -fn create_case_context() -> Result<ExecutionContext> { - let mut ctx = ExecutionContext::new(); +fn create_case_context() -> Result<SessionContext> { + let mut ctx = SessionContext::new(); let schema = Arc::new(Schema::new(vec![Field::new("c1", DataType::Utf8, true)])); let data = RecordBatch::try_new( schema.clone(), @@ -167,11 +167,8 @@ fn create_case_context() -> Result<ExecutionContext> { Ok(ctx) } -fn create_join_context( - column_left: &str, - column_right: &str, -) -> Result<ExecutionContext> { - let mut ctx = ExecutionContext::new(); +fn create_join_context(column_left: &str, column_right: &str) -> Result<SessionContext> { + let mut ctx = SessionContext::new(); let t1_schema = Arc::new(Schema::new(vec![ Field::new(column_left, DataType::UInt32, true), @@ -214,8 +211,8 @@ fn create_join_context( Ok(ctx) } -fn create_join_context_qualified() -> Result<ExecutionContext> { - let mut ctx = ExecutionContext::new(); +fn create_join_context_qualified() -> Result<SessionContext> { + let mut ctx = SessionContext::new(); let t1_schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::UInt32, true), @@ -256,8 +253,8 @@ fn create_join_context_qualified() -> Result<ExecutionContext> { fn create_join_context_unbalanced( column_left: &str, column_right: &str, -) -> Result<ExecutionContext> { - let mut ctx = ExecutionContext::new(); +) -> Result<SessionContext> { + let mut ctx = SessionContext::new(); let t1_schema = Arc::new(Schema::new(vec![ Field::new(column_left, DataType::UInt32, true), @@ -302,8 +299,8 @@ fn create_join_context_unbalanced( } // Create memory tables with nulls -fn create_join_context_with_nulls() -> Result<ExecutionContext> { - let mut ctx = ExecutionContext::new(); +fn create_join_context_with_nulls() -> Result<SessionContext> { + let mut ctx = SessionContext::new(); let t1_schema = Arc::new(Schema::new(vec![ Field::new("t1_id", DataType::UInt32, true), @@ -405,7 +402,7 @@ fn get_tpch_table_schema(table: &str) -> Schema { } } -async fn register_tpch_csv(ctx: &mut ExecutionContext, table: &str) -> Result<()> { +async fn register_tpch_csv(ctx: &mut SessionContext, table: &str) -> Result<()> { let schema = get_tpch_table_schema(table); 
ctx.register_csv( @@ -417,7 +414,7 @@ async fn register_tpch_csv(ctx: &mut ExecutionContext, table: &str) -> Result<() Ok(()) } -async fn register_aggregate_csv_by_sql(ctx: &mut ExecutionContext) { +async fn register_aggregate_csv_by_sql(ctx: &mut SessionContext) { let testdata = datafusion::test_util::arrow_test_data(); // TODO: The following c9 should be migrated to UInt32 and c10 should be UInt64 once @@ -459,7 +456,7 @@ async fn register_aggregate_csv_by_sql(ctx: &mut ExecutionContext) { } /// Create table "t1" with two boolean columns "a" and "b" -async fn register_boolean(ctx: &mut ExecutionContext) -> Result<()> { +async fn register_boolean(ctx: &mut SessionContext) -> Result<()> { let a: BooleanArray = [ Some(true), Some(true), @@ -494,7 +491,7 @@ async fn register_boolean(ctx: &mut ExecutionContext) -> Result<()> { Ok(()) } -async fn register_aggregate_simple_csv(ctx: &mut ExecutionContext) -> Result<()> { +async fn register_aggregate_simple_csv(ctx: &mut SessionContext) -> Result<()> { // It's not possible to use aggregate_test_100, not enought similar values to test grouping on floats let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::Float32, false), @@ -511,7 +508,7 @@ async fn register_aggregate_simple_csv(ctx: &mut ExecutionContext) -> Result<()> Ok(()) } -async fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> { +async fn register_aggregate_csv(ctx: &mut SessionContext) -> Result<()> { let testdata = datafusion::test_util::arrow_test_data(); let schema = test_util::aggr_test_schema(); ctx.register_csv( @@ -525,14 +522,14 @@ async fn register_aggregate_csv(ctx: &mut ExecutionContext) -> Result<()> { /// Execute SQL and return results as a RecordBatch async fn plan_and_collect( - ctx: &mut ExecutionContext, + ctx: &mut SessionContext, sql: &str, ) -> Result<Vec<RecordBatch>> { ctx.sql(sql).await?.collect().await } /// Execute query and return results as a Vec of RecordBatches -async fn execute_to_batches(ctx: &mut ExecutionContext, sql: &str) -> Vec<RecordBatch> { +async fn execute_to_batches(ctx: &SessionContext, sql: &str) -> Vec<RecordBatch> { let msg = format!("Creating logical plan for '{}'", sql); let plan = ctx.create_logical_plan(sql).expect(&msg); let logical_schema = plan.schema(); @@ -545,8 +542,8 @@ async fn execute_to_batches(ctx: &mut ExecutionContext, sql: &str) -> Vec Vec Vec> { +async fn execute(ctx: &mut SessionContext, sql: &str) -> Vec<Vec<String>> { result_vec(&execute_to_batches(ctx, sql).await) } @@ -601,7 +598,7 @@ fn result_vec(results: &[RecordBatch]) -> Vec<Vec<String>> { result } -async fn register_simple_aggregate_csv_with_decimal_by_sql(ctx: &mut ExecutionContext) { +async fn register_simple_aggregate_csv_with_decimal_by_sql(ctx: &mut SessionContext) { let df = ctx .sql( "CREATE EXTERNAL TABLE aggregate_simple ( @@ -623,7 +620,7 @@ async fn register_simple_aggregate_csv_with_decimal_by_sql(ctx: &mut ExecutionCo ); } -async fn register_alltypes_parquet(ctx: &mut ExecutionContext) { +async fn register_alltypes_parquet(ctx: &mut SessionContext) { let testdata = datafusion::test_util::parquet_test_data(); ctx.register_parquet( "alltypes_plain", @@ -832,7 +829,7 @@ async fn nyc() -> Result<()> { Field::new("total_amount", DataType::Float64, true), ]); - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_csv( "tripdata", "file.csv", diff --git a/datafusion/tests/sql/order.rs b/datafusion/tests/sql/order.rs index d23c817789510..c6613d581bc38 100644 --- a/datafusion/tests/sql/order.rs +++ b/datafusion/tests/sql/order.rs @@ -19,11 +19,11 @@ use 
super::*; #[tokio::test] async fn test_sort_unprojected_col() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_alltypes_parquet(&mut ctx).await; // execute the query let sql = "SELECT id FROM alltypes_plain ORDER BY int_col, double_col"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+", "| id |", "+----+", "| 4 |", "| 6 |", "| 2 |", "| 0 |", "| 5 |", "| 7 |", "| 3 |", "| 1 |", "+----+", @@ -34,10 +34,10 @@ async fn test_sort_unprojected_col() -> Result<()> { #[tokio::test] async fn test_order_by_agg_expr() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT MIN(c12) FROM aggregate_test_100 ORDER BY MIN(c12)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------------------+", "| MIN(aggregate_test_100.c12) |", @@ -48,16 +48,16 @@ async fn test_order_by_agg_expr() -> Result<()> { assert_batches_eq!(expected, &actual); let sql = "SELECT MIN(c12) FROM aggregate_test_100 ORDER BY MIN(c12) + 0.1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn test_nulls_first_asc() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----+--------+", "| num | letter |", @@ -73,9 +73,9 @@ async fn test_nulls_first_asc() -> Result<()> { #[tokio::test] async fn test_nulls_first_desc() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num DESC"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----+--------+", "| num | letter |", @@ -91,9 +91,9 @@ async fn test_nulls_first_desc() -> Result<()> { #[tokio::test] async fn test_specific_nulls_last_desc() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num DESC NULLS LAST"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----+--------+", "| num | letter |", @@ -109,9 +109,9 @@ async fn test_specific_nulls_last_desc() -> Result<()> { #[tokio::test] async fn test_specific_nulls_first_asc() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT * FROM (VALUES (1, 'one'), (2, 'two'), (null, 'three')) AS t (num,letter) ORDER BY num ASC NULLS FIRST"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----+--------+", "| num | letter |", diff --git a/datafusion/tests/sql/parquet.rs b/datafusion/tests/sql/parquet.rs index 37912c8751c82..77949938cc7ab 100644 --- a/datafusion/tests/sql/parquet.rs +++ b/datafusion/tests/sql/parquet.rs @@ -24,12 +24,12 @@ use 
super::*; #[tokio::test] async fn parquet_query() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_alltypes_parquet(&mut ctx).await; // NOTE that string_col is actually a binary column and does not have the UTF8 logical type // so we need an explicit cast let sql = "SELECT id, CAST(string_col AS varchar) FROM alltypes_plain"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+-----------------------------------------+", "| id | CAST(alltypes_plain.string_col AS Utf8) |", @@ -50,7 +50,7 @@ async fn parquet_query() { #[tokio::test] async fn parquet_single_nan_schema() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let testdata = datafusion::test_util::parquet_test_data(); ctx.register_parquet("single_nan", &format!("{}/single_nan.parquet", testdata)) .await @@ -59,8 +59,8 @@ async fn parquet_single_nan_schema() { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().runtime_env.clone(); - let results = collect(plan, runtime).await.unwrap(); + let task_ctx = ctx.task_ctx(); + let results = collect(plan, task_ctx).await.unwrap(); for batch in results { assert_eq!(1, batch.num_rows()); assert_eq!(1, batch.num_columns()); @@ -70,7 +70,7 @@ async fn parquet_single_nan_schema() { #[tokio::test] #[ignore = "Test ignored, will be enabled as part of the nested Parquet reader"] async fn parquet_list_columns() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let testdata = datafusion::test_util::parquet_test_data(); ctx.register_parquet( "list_columns", @@ -96,8 +96,8 @@ async fn parquet_list_columns() { let plan = ctx.create_logical_plan(sql).unwrap(); let plan = ctx.optimize(&plan).unwrap(); let plan = ctx.create_physical_plan(&plan).await.unwrap(); - let runtime = ctx.state.lock().runtime_env.clone(); - let results = collect(plan, runtime).await.unwrap(); + let task_ctx = ctx.task_ctx(); + let results = collect(plan, task_ctx).await.unwrap(); // int64_list utf8_list // 0 [1, 2, 3] [abc, efg, hij] @@ -212,7 +212,7 @@ async fn schema_merge_ignores_metadata() { // Read the parquet files into a dataframe to confirm results // (no errors) - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let df = ctx .read_parquet(table_dir.to_str().unwrap().to_string()) .await diff --git a/datafusion/tests/sql/partitioned_csv.rs b/datafusion/tests/sql/partitioned_csv.rs index 3394887ad0b83..e90b7a0beffa9 100644 --- a/datafusion/tests/sql/partitioned_csv.rs +++ b/datafusion/tests/sql/partitioned_csv.rs @@ -25,13 +25,13 @@ use arrow::{ }; use datafusion::{ error::Result, - prelude::{CsvReadOptions, ExecutionConfig, ExecutionContext}, + prelude::{CsvReadOptions, SessionConfig, SessionContext}, }; use tempfile::TempDir; /// Execute SQL and return results async fn plan_and_collect( - ctx: &mut ExecutionContext, + ctx: &mut SessionContext, sql: &str, ) -> Result> { ctx.sql(sql).await?.collect().await @@ -77,9 +77,9 @@ fn populate_csv_partitions( pub async fn create_ctx( tmp_dir: &TempDir, partition_count: usize, -) -> Result { +) -> Result { let mut ctx = - ExecutionContext::with_config(ExecutionConfig::new().with_target_partitions(8)); + SessionContext::with_config(SessionConfig::new().with_target_partitions(8)); let schema = populate_csv_partitions(tmp_dir, partition_count, 
".csv")?; diff --git a/datafusion/tests/sql/predicates.rs b/datafusion/tests/sql/predicates.rs index f4e1f4f4deef9..879107c84e946 100644 --- a/datafusion/tests/sql/predicates.rs +++ b/datafusion/tests/sql/predicates.rs @@ -19,10 +19,10 @@ use super::*; #[tokio::test] async fn csv_query_with_predicate() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c1, c12 FROM aggregate_test_100 WHERE c12 > 0.376 AND c12 < 0.4"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+---------------------+", "| c1 | c12 |", @@ -37,10 +37,10 @@ async fn csv_query_with_predicate() -> Result<()> { #[tokio::test] async fn csv_query_with_negative_predicate() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c1, c4 FROM aggregate_test_100 WHERE c3 < -55 AND -c4 > 30000"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+--------+", "| c1 | c4 |", @@ -55,10 +55,10 @@ async fn csv_query_with_negative_predicate() -> Result<()> { #[tokio::test] async fn csv_query_with_negated_predicate() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT COUNT(1) FROM aggregate_test_100 WHERE NOT(c1 != 'a')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", "| COUNT(UInt8(1)) |", @@ -72,10 +72,10 @@ async fn csv_query_with_negated_predicate() -> Result<()> { #[tokio::test] async fn csv_query_with_is_not_null_predicate() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT COUNT(1) FROM aggregate_test_100 WHERE c1 IS NOT NULL"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", "| COUNT(UInt8(1)) |", @@ -89,10 +89,10 @@ async fn csv_query_with_is_not_null_predicate() -> Result<()> { #[tokio::test] async fn csv_query_with_is_null_predicate() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT COUNT(1) FROM aggregate_test_100 WHERE c1 IS NULL"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", "| COUNT(UInt8(1)) |", @@ -106,12 +106,12 @@ async fn csv_query_with_is_null_predicate() -> Result<()> { #[tokio::test] async fn query_where_neg_num() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; // Negative numbers do not parse correctly as of Arrow 2.0.0 let sql = "select c7, c8 from aggregate_test_100 where c7 >= -2 and c7 < 10"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+-------+", "| c7 | c8 |", @@ -127,18 +127,18 @@ async fn query_where_neg_num() -> Result<()> { // Also check floating point neg numbers let sql = "select c7, c8 from aggregate_test_100 where c7 >= -2.9 and c7 < 
10"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } #[tokio::test] async fn like() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv_by_sql(&mut ctx).await; let sql = "SELECT COUNT(c1) FROM aggregate_test_100 WHERE c13 LIKE '%FB%'"; // check that the physical and logical schemas are equal - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------------------------------+", "| COUNT(aggregate_test_100.c1) |", @@ -152,10 +152,10 @@ async fn like() -> Result<()> { #[tokio::test] async fn csv_between_expr() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c4 FROM aggregate_test_100 WHERE c12 BETWEEN 0.995 AND 1.0"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| c4 |", @@ -169,10 +169,10 @@ async fn csv_between_expr() -> Result<()> { #[tokio::test] async fn csv_between_expr_negated() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c4 FROM aggregate_test_100 WHERE c12 NOT BETWEEN 0 AND 0.995"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| c4 |", @@ -193,11 +193,11 @@ async fn like_on_strings() -> Result<()> { let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap(); let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT * FROM test WHERE c1 LIKE '%a%'"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| c1 |", @@ -220,11 +220,11 @@ async fn like_on_string_dictionaries() -> Result<()> { let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap(); let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT * FROM test WHERE c1 LIKE '%a%'"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| c1 |", @@ -247,11 +247,11 @@ async fn test_regexp_is_match() -> Result<()> { let batch = RecordBatch::try_from_iter(vec![("c1", Arc::new(input) as _)]).unwrap(); let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT * FROM test WHERE c1 ~ 'z'"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| c1 |", @@ -262,7 +262,7 @@ async fn test_regexp_is_match() -> Result<()> { assert_batches_eq!(expected, &actual); let sql = "SELECT * FROM test WHERE c1 ~* 'z'"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = 
vec![ "+-------+", "| c1 |", @@ -274,7 +274,7 @@ async fn test_regexp_is_match() -> Result<()> { assert_batches_eq!(expected, &actual); let sql = "SELECT * FROM test WHERE c1 !~ 'z'"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| c1 |", @@ -287,7 +287,7 @@ async fn test_regexp_is_match() -> Result<()> { assert_batches_eq!(expected, &actual); let sql = "SELECT * FROM test WHERE c1 !~* 'z'"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| c1 |", @@ -313,8 +313,8 @@ async fn except_with_null_not_equal() { "+-----+-----+", ]; - let mut ctx = create_join_context_qualified().unwrap(); - let actual = execute_to_batches(&mut ctx, sql).await; + let ctx = create_join_context_qualified().unwrap(); + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } @@ -325,19 +325,19 @@ async fn except_with_null_equal() { EXCEPT SELECT * FROM (SELECT null AS id1, 1 AS id2) t2"; let expected = vec!["++", "++"]; - let mut ctx = create_join_context_qualified().unwrap(); - let actual = execute_to_batches(&mut ctx, sql).await; + let ctx = create_join_context_qualified().unwrap(); + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); } #[tokio::test] async fn test_expect_all() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_alltypes_parquet(&mut ctx).await; // execute the query let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 EXCEPT ALL SELECT int_col, double_col FROM alltypes_plain where int_col < 1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+------------+", "| int_col | double_col |", @@ -354,11 +354,11 @@ async fn test_expect_all() -> Result<()> { #[tokio::test] async fn test_expect_distinct() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_alltypes_parquet(&mut ctx).await; // execute the query let sql = "SELECT int_col, double_col FROM alltypes_plain where int_col > 0 EXCEPT SELECT int_col, double_col FROM alltypes_plain where int_col < 1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+------------+", "| int_col | double_col |", diff --git a/datafusion/tests/sql/projection.rs b/datafusion/tests/sql/projection.rs index 0a956a9411eb1..d2bcfbcd6d635 100644 --- a/datafusion/tests/sql/projection.rs +++ b/datafusion/tests/sql/projection.rs @@ -22,10 +22,10 @@ use super::*; #[tokio::test] async fn projection_same_fields() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "select (1+1) as a from (select 1 as a) as b;"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["+---+", "| a |", "+---+", "| 2 |", "+---+"]; assert_batches_eq!(expected, &actual); @@ -35,13 +35,13 @@ async fn projection_same_fields() -> Result<()> { #[tokio::test] async fn projection_type_alias() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_simple_csv(&mut ctx).await?; // Query that aliases one column to the name of a different column // that also has a different 
type (c1 == float32, c3 == boolean) let sql = "SELECT c1 as c3 FROM aggregate_simple ORDER BY c3 LIMIT 2"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+", @@ -58,10 +58,10 @@ async fn projection_type_alias() -> Result<()> { #[tokio::test] async fn csv_query_group_by_avg_with_projection() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT avg(c12), c1 FROM aggregate_test_100 GROUP BY c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------------------+----+", "| AVG(aggregate_test_100.c12) | c1 |", @@ -139,7 +139,6 @@ async fn projection_on_table_scan() -> Result<()> { let tmp_dir = TempDir::new()?; let partition_count = 4; let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?; - let runtime = ctx.state.lock().runtime_env.clone(); let table = ctx.table("test")?; let logical_plan = LogicalPlanBuilder::from(table.to_logical_plan()) @@ -170,8 +169,8 @@ async fn projection_on_table_scan() -> Result<()> { assert_eq!(1, physical_plan.schema().fields().len()); assert_eq!("c2", physical_plan.schema().field(0).name().as_str()); - - let batches = collect(physical_plan, runtime).await?; + let task_ctx = ctx.task_ctx(); + let batches = collect(physical_plan, task_ctx).await?; assert_eq!(40, batches.iter().map(|x| x.num_rows()).sum::<usize>()); Ok(()) @@ -218,7 +217,7 @@ async fn projection_on_memory_scan() -> Result<()> { .build()?; assert_fields_eq(&plan, vec!["b"]); - let ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let optimized_plan = ctx.optimize(&plan)?; match &optimized_plan { LogicalPlan::Projection(Projection { input, ..
}) => match &**input { @@ -247,8 +246,8 @@ async fn projection_on_memory_scan() -> Result<()> { assert_eq!(1, physical_plan.schema().fields().len()); assert_eq!("b", physical_plan.schema().field(0).name().as_str()); - let runtime = ctx.state.lock().runtime_env.clone(); - let batches = collect(physical_plan, runtime).await?; + let task_ctx = ctx.task_ctx(); + let batches = collect(physical_plan, task_ctx).await?; assert_eq!(1, batches.len()); assert_eq!(1, batches[0].num_columns()); assert_eq!(4, batches[0].num_rows()); diff --git a/datafusion/tests/sql/references.rs b/datafusion/tests/sql/references.rs index 779c6a3366732..ec658d5340637 100644 --- a/datafusion/tests/sql/references.rs +++ b/datafusion/tests/sql/references.rs @@ -19,7 +19,7 @@ use super::*; #[tokio::test] async fn qualified_table_references() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; for table_ref in &[ @@ -28,7 +28,7 @@ async fn qualified_table_references() -> Result<()> { "datafusion.public.aggregate_test_100", ] { let sql = format!("SELECT COUNT(*) FROM {}", table_ref); - let actual = execute_to_batches(&mut ctx, &sql).await; + let actual = execute_to_batches(&ctx, &sql).await; let expected = vec![ "+-----------------+", "| COUNT(UInt8(1)) |", @@ -43,7 +43,7 @@ async fn qualified_table_references() -> Result<()> { #[tokio::test] async fn qualified_table_references_and_fields() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let c1: StringArray = vec!["foofoo", "foobar", "foobaz"] .into_iter() @@ -73,7 +73,7 @@ async fn qualified_table_references_and_fields() -> Result<()> { // however, enclosing it in double quotes is ok let sql = r#"SELECT "f.c1" from test"#; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------+", "| f.c1 |", @@ -86,12 +86,12 @@ async fn qualified_table_references_and_fields() -> Result<()> { assert_batches_eq!(expected, &actual); // Works fully qualified too let sql = r#"SELECT test."f.c1" from test"#; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); // check that duplicated table name and column name are ok let sql = r#"SELECT "test.c2" as expr1, test."test.c2" as expr2 from test"#; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+-------+", "| expr1 | expr2 |", @@ -107,7 +107,7 @@ async fn qualified_table_references_and_fields() -> Result<()> { // datafusion should run the query, not that someone should write // this let sql = r#"SELECT "....", "...." as c3 from test order by "....""#; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------+----+", "| .... 
| c3 |", @@ -123,7 +123,7 @@ async fn qualified_table_references_and_fields() -> Result<()> { #[tokio::test] async fn test_partial_qualified_name() -> Result<()> { - let mut ctx = create_join_context("t1_id", "t2_id")?; + let ctx = create_join_context("t1_id", "t2_id")?; let sql = "SELECT t1.t1_id, t1_name FROM public.t1"; let expected = vec![ "+-------+---------+", @@ -135,7 +135,7 @@ async fn test_partial_qualified_name() -> Result<()> { "| 44 | d |", "+-------+---------+", ]; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; assert_batches_eq!(expected, &actual); Ok(()) } diff --git a/datafusion/tests/sql/select.rs b/datafusion/tests/sql/select.rs index 6ba190856a46a..54120e023be9b 100644 --- a/datafusion/tests/sql/select.rs +++ b/datafusion/tests/sql/select.rs @@ -24,12 +24,12 @@ use tempfile::TempDir; #[tokio::test] async fn all_where_empty() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT * FROM aggregate_test_100 WHERE 1=2"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["++", "++"]; assert_batches_eq!(expected, &actual); Ok(()) @@ -37,10 +37,10 @@ async fn all_where_empty() -> Result<()> { #[tokio::test] async fn select_values_list() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); { let sql = "VALUES (1)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+", "| column1 |", @@ -52,7 +52,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (-1)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+", "| column1 |", @@ -64,7 +64,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (2+1,2-1,2>1)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+---------+---------+", "| column1 | column2 | column3 |", @@ -86,7 +86,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (1),(2)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+", "| column1 |", @@ -104,7 +104,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (1,'a'),(2,'b')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+---------+", "| column1 | column2 |", @@ -137,7 +137,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (1,'a'),(NULL,'b'),(3,'c')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+---------+", "| column1 | column2 |", @@ -151,7 +151,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (NULL,'a'),(NULL,'b'),(3,'c')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+---------+", "| column1 | column2 |", @@ -165,7 +165,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (NULL,'a'),(NULL,'b'),(NULL,'c')"; - let actual = execute_to_batches(&mut ctx, sql).await; + 
let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+---------+", "| column1 | column2 |", @@ -179,7 +179,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (1,'a'),(2,NULL),(3,'c')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+---------+", "| column1 | column2 |", @@ -193,7 +193,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (1,NULL),(2,NULL),(3,'c')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+---------+", "| column1 | column2 |", @@ -207,7 +207,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "VALUES (1,2,3,4,5,6,7,8,9,10,11,12,13,NULL,'F',3.5)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+---------+---------+---------+---------+---------+---------+---------+---------+----------+----------+----------+----------+----------+----------+----------+", "| column1 | column2 | column3 | column4 | column5 | column6 | column7 | column8 | column9 | column10 | column11 | column12 | column13 | column14 | column15 | column16 |", @@ -219,7 +219,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "SELECT * FROM (VALUES (1,'a'),(2,NULL)) AS t(c1, c2)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+----+", "| c1 | c2 |", @@ -232,7 +232,7 @@ async fn select_values_list() -> Result<()> { } { let sql = "EXPLAIN VALUES (1, 'a', -1, 1.1),(NULL, 'b', -3, 0.5)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------+-----------------------------------------------------------------------------------------------------------+", "| plan_type | plan |", @@ -249,14 +249,14 @@ async fn select_values_list() -> Result<()> { #[tokio::test] async fn select_all() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_simple_csv(&mut ctx).await?; let sql = "SELECT c1 FROM aggregate_simple order by c1"; - let results = execute_to_batches(&mut ctx, sql).await; + let results = execute_to_batches(&ctx, sql).await; let sql_all = "SELECT ALL c1 FROM aggregate_simple order by c1"; - let results_all = execute_to_batches(&mut ctx, sql_all).await; + let results_all = execute_to_batches(&ctx, sql_all).await; let expected = vec![ "+---------+", @@ -288,7 +288,7 @@ async fn select_all() -> Result<()> { #[tokio::test] async fn select_distinct() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_simple_csv(&mut ctx).await?; let sql = "SELECT DISTINCT * FROM aggregate_simple"; @@ -305,11 +305,11 @@ async fn select_distinct() -> Result<()> { #[tokio::test] async fn select_distinct_simple_1() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_simple_csv(&mut ctx).await.unwrap(); let sql = "SELECT DISTINCT c1 FROM aggregate_simple order by c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+", @@ -327,11 +327,11 @@ async fn select_distinct_simple_1() { #[tokio::test] async fn select_distinct_simple_2() { - let 
mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_simple_csv(&mut ctx).await.unwrap(); let sql = "SELECT DISTINCT c1, c2 FROM aggregate_simple order by c1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------+----------------+", @@ -349,11 +349,11 @@ async fn select_distinct_simple_2() { #[tokio::test] async fn select_distinct_simple_3() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_simple_csv(&mut ctx).await.unwrap(); let sql = "SELECT distinct c3 FROM aggregate_simple order by c3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", @@ -368,11 +368,11 @@ async fn select_distinct_simple_3() { #[tokio::test] async fn select_distinct_simple_4() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_simple_csv(&mut ctx).await.unwrap(); let sql = "SELECT distinct c1+c2 as a FROM aggregate_simple"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------+", @@ -390,7 +390,7 @@ async fn select_distinct_simple_4() { #[tokio::test] async fn select_distinct_from() { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "select 1 IS DISTINCT FROM CAST(NULL as INT) as a, @@ -400,7 +400,7 @@ async fn select_distinct_from() { NULL IS DISTINCT FROM NULL as e, NULL IS NOT DISTINCT FROM NULL as f "; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------+-------+-------+------+-------+------+", "| a | b | c | d | e | f |", @@ -413,7 +413,7 @@ async fn select_distinct_from() { #[tokio::test] async fn select_distinct_from_utf8() { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "select 'x' IS DISTINCT FROM NULL as a, @@ -421,7 +421,7 @@ async fn select_distinct_from_utf8() { 'x' IS NOT DISTINCT FROM NULL as c, 'x' IS NOT DISTINCT FROM 'x' as d "; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------+-------+-------+------+", "| a | b | c | d |", @@ -434,10 +434,10 @@ async fn select_distinct_from_utf8() { #[tokio::test] async fn csv_query_with_decimal_by_sql() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_simple_aggregate_csv_with_decimal_by_sql(&mut ctx).await; let sql = "SELECT c1 from aggregate_simple"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------+", "| c1 |", @@ -465,10 +465,10 @@ async fn csv_query_with_decimal_by_sql() -> Result<()> { #[tokio::test] async fn use_between_expression_in_select_query() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let sql = "SELECT 1 NOT BETWEEN 3 AND 5"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------------------------------------+", "| Int64(1) NOT BETWEEN Int64(3) AND Int64(5) |", @@ -484,7 +484,7 @@ async fn use_between_expression_in_select_query() -> Result<()> { ctx.register_table("test", Arc::new(table))?; let sql = 
"SELECT abs(c1) BETWEEN 0 AND LoG(c1 * 100 ) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; // Expect field name to be correctly converted for expr, low and high. let expected = vec![ "+--------------------------------------------------------------------+", @@ -499,7 +499,7 @@ async fn use_between_expression_in_select_query() -> Result<()> { assert_batches_eq!(expected, &actual); let sql = "EXPLAIN SELECT c1 BETWEEN 2 AND 3 FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let formatted = arrow::util::pretty::pretty_format_batches(&actual) .unwrap() .to_string(); @@ -515,7 +515,7 @@ async fn use_between_expression_in_select_query() -> Result<()> { #[tokio::test] async fn query_get_indexed_field() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let schema = Arc::new(Schema::new(vec![Field::new( "some_list", DataType::List(Box::new(Field::new("item", DataType::Int64, true))), @@ -539,7 +539,7 @@ async fn query_get_indexed_field() -> Result<()> { // Original column is micros, convert to millis and check timestamp let sql = "SELECT some_list[0] as i0 FROM ints LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+", "| i0 |", "+----+", "| 0 |", "| 4 |", "| 7 |", "+----+", ]; @@ -549,7 +549,7 @@ async fn query_get_indexed_field() -> Result<()> { #[tokio::test] async fn query_nested_get_indexed_field() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let nested_dt = DataType::List(Box::new(Field::new("item", DataType::Int64, true))); // Nested schema of { "some_list": [[i64]] } let schema = Arc::new(Schema::new(vec![Field::new( @@ -585,7 +585,7 @@ async fn query_nested_get_indexed_field() -> Result<()> { // Original column is micros, convert to millis and check timestamp let sql = "SELECT some_list[0] as i0 FROM ints LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------+", "| i0 |", @@ -597,7 +597,7 @@ async fn query_nested_get_indexed_field() -> Result<()> { ]; assert_batches_eq!(expected, &actual); let sql = "SELECT some_list[0][0] as i0 FROM ints LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+", "| i0 |", "+----+", "| 0 |", "| 5 |", "| 11 |", "+----+", ]; @@ -607,7 +607,7 @@ async fn query_nested_get_indexed_field() -> Result<()> { #[tokio::test] async fn query_nested_get_indexed_field_on_struct() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let nested_dt = DataType::List(Box::new(Field::new("item", DataType::Int64, true))); // Nested schema of { "some_struct": { "bar": [i64] } } let struct_fields = vec![Field::new("bar", nested_dt.clone(), true)]; @@ -635,7 +635,7 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> { // Original column is micros, convert to millis and check timestamp let sql = "SELECT some_struct[\"bar\"] as l0 FROM structs LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------+", "| l0 |", @@ -647,7 +647,7 @@ async fn query_nested_get_indexed_field_on_struct() -> Result<()> { ]; 
assert_batches_eq!(expected, &actual); let sql = "SELECT some_struct[\"bar\"][0] as i0 FROM structs LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+", "| i0 |", "+----+", "| 0 |", "| 4 |", "| 8 |", "+----+", ]; @@ -676,12 +676,12 @@ async fn query_on_string_dictionary() -> Result<()> { .unwrap(); let table = MemTable::try_new(batch.schema(), vec![vec![batch]])?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; // Basic SELECT let sql = "SELECT d1 FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| d1 |", @@ -695,7 +695,7 @@ async fn query_on_string_dictionary() -> Result<()> { // basic filtering let sql = "SELECT d1 FROM test WHERE d1 IS NOT NULL"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| d1 |", @@ -708,7 +708,7 @@ async fn query_on_string_dictionary() -> Result<()> { // comparison with constant let sql = "SELECT d1 FROM test WHERE d1 = 'three'"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| d1 |", @@ -720,7 +720,7 @@ async fn query_on_string_dictionary() -> Result<()> { // comparison with another dictionary column let sql = "SELECT d1 FROM test WHERE d1 = d2"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| d1 |", @@ -732,7 +732,7 @@ async fn query_on_string_dictionary() -> Result<()> { // order comparison with another dictionary column let sql = "SELECT d1 FROM test WHERE d1 <= d2"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| d1 |", @@ -744,7 +744,7 @@ async fn query_on_string_dictionary() -> Result<()> { // comparison with a non dictionary column let sql = "SELECT d1 FROM test WHERE d1 = d3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| d1 |", @@ -756,7 +756,7 @@ async fn query_on_string_dictionary() -> Result<()> { // filtering with constant let sql = "SELECT d1 FROM test WHERE d1 = 'three'"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+", "| d1 |", @@ -768,7 +768,7 @@ async fn query_on_string_dictionary() -> Result<()> { // Expression evaluation let sql = "SELECT concat(d1, '-foo') FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+------------------------------+", "| concat(test.d1,Utf8(\"-foo\")) |", @@ -782,7 +782,7 @@ async fn query_on_string_dictionary() -> Result<()> { // Expression evaluation with two dictionaries let sql = "SELECT concat(d1, d2) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------+", "| concat(test.d1,test.d2) |", @@ -796,7 +796,7 @@ async fn query_on_string_dictionary() -> Result<()> { // aggregation let sql = "SELECT COUNT(d1) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + 
let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------+", "| COUNT(test.d1) |", @@ -808,7 +808,7 @@ async fn query_on_string_dictionary() -> Result<()> { // aggregation min let sql = "SELECT MIN(d1) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------+", "| MIN(test.d1) |", @@ -820,7 +820,7 @@ async fn query_on_string_dictionary() -> Result<()> { // aggregation max let sql = "SELECT MAX(d1) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------+", "| MAX(test.d1) |", @@ -832,7 +832,7 @@ async fn query_on_string_dictionary() -> Result<()> { // grouping let sql = "SELECT d1, COUNT(*) FROM test group by d1"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+-----------------+", "| d1 | COUNT(UInt8(1)) |", @@ -846,7 +846,7 @@ async fn query_on_string_dictionary() -> Result<()> { // window functions let sql = "SELECT d1, row_number() OVER (partition by d1) FROM test"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------+--------------+", "| d1 | ROW_NUMBER() |", @@ -865,11 +865,11 @@ async fn query_on_string_dictionary() -> Result<()> { async fn query_cte() -> Result<()> { // Test for SELECT without FROM. // Should evaluate expressions in project position. - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); // simple with let sql = "WITH t AS (SELECT 1) SELECT * FROM t"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------+", "| Int64(1) |", @@ -882,19 +882,19 @@ async fn query_cte() -> Result<()> { // with + union let sql = "WITH t AS (SELECT 1 AS a), u AS (SELECT 2 AS a) SELECT * FROM t UNION ALL SELECT * FROM u"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["+---+", "| a |", "+---+", "| 1 |", "| 2 |", "+---+"]; assert_batches_eq!(expected, &actual); // with + join let sql = "WITH t AS (SELECT 1 AS id1), u AS (SELECT 1 AS id2, 5 as x) SELECT x FROM t JOIN u ON (id1 = id2)"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["+---+", "| x |", "+---+", "| 5 |", "+---+"]; assert_batches_eq!(expected, &actual); // backward reference let sql = "WITH t AS (SELECT 1 AS id1), u AS (SELECT * FROM t) SELECT * from u"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["+-----+", "| id1 |", "+-----+", "| 1 |", "+-----+"]; assert_batches_eq!(expected, &actual); @@ -903,7 +903,7 @@ async fn query_cte() -> Result<()> { #[tokio::test] async fn csv_select_nested() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT o1, o2, c3 FROM ( @@ -915,7 +915,7 @@ async fn csv_select_nested() -> Result<()> { ORDER BY c2 ASC, c3 ASC ) AS a ) AS b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----+----+------+", "| o1 | o2 | c3 |", @@ -945,8 +945,8 @@ async fn 
parallel_query_with_filter() -> Result<()> { let physical_plan = ctx.create_physical_plan(&logical_plan).await?; - let runtime = ctx.state.lock().runtime_env.clone(); - let results = collect_partitioned(physical_plan, runtime).await?; + let task_ctx = ctx.task_ctx(); + let results = collect_partitioned(physical_plan, task_ctx).await?; // note that the order of partitions is not deterministic let mut num_rows = 0; @@ -991,7 +991,7 @@ async fn parallel_query_with_filter() -> Result<()> { #[tokio::test] async fn query_empty_table() { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let empty_table = Arc::new(EmptyTable::new(Arc::new(Schema::empty()))); ctx.register_table("test_tbl", empty_table).unwrap(); let sql = "SELECT * FROM test_tbl"; diff --git a/datafusion/tests/sql/timestamp.rs b/datafusion/tests/sql/timestamp.rs index 42aa3f4501631..4e0e7a8c79a89 100644 --- a/datafusion/tests/sql/timestamp.rs +++ b/datafusion/tests/sql/timestamp.rs @@ -20,7 +20,7 @@ use datafusion::from_slice::FromSlice; #[tokio::test] async fn query_cast_timestamp_millis() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let t1_schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, true)])); let t1_data = RecordBatch::try_new( @@ -35,7 +35,7 @@ async fn query_cast_timestamp_millis() -> Result<()> { ctx.register_table("t1", Arc::new(t1_table))?; let sql = "SELECT to_timestamp_millis(ts) FROM t1 LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------------------+", @@ -52,7 +52,7 @@ async fn query_cast_timestamp_millis() -> Result<()> { #[tokio::test] async fn query_cast_timestamp_micros() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let t1_schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, true)])); let t1_data = RecordBatch::try_new( @@ -67,7 +67,7 @@ async fn query_cast_timestamp_micros() -> Result<()> { ctx.register_table("t1", Arc::new(t1_table))?; let sql = "SELECT to_timestamp_micros(ts) FROM t1 LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------------------+", @@ -85,7 +85,7 @@ async fn query_cast_timestamp_micros() -> Result<()> { #[tokio::test] async fn query_cast_timestamp_seconds() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let t1_schema = Arc::new(Schema::new(vec![Field::new("ts", DataType::Int64, true)])); let t1_data = RecordBatch::try_new( @@ -98,7 +98,7 @@ async fn query_cast_timestamp_seconds() -> Result<()> { ctx.register_table("t1", Arc::new(t1_table))?; let sql = "SELECT to_timestamp_seconds(ts) FROM t1 LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------------+", @@ -116,12 +116,12 @@ async fn query_cast_timestamp_seconds() -> Result<()> { #[tokio::test] async fn query_cast_timestamp_nanos_to_others() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("ts_data", make_timestamp_nano_table()?)?; // Original column is nanos, convert to millis and check timestamp let sql = "SELECT to_timestamp_millis(ts) FROM ts_data LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = 
execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------------+", @@ -135,7 +135,7 @@ async fn query_cast_timestamp_nanos_to_others() -> Result<()> { assert_batches_eq!(expected, &actual); let sql = "SELECT to_timestamp_micros(ts) FROM ts_data LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------------+", @@ -149,7 +149,7 @@ async fn query_cast_timestamp_nanos_to_others() -> Result<()> { assert_batches_eq!(expected, &actual); let sql = "SELECT to_timestamp_seconds(ts) FROM ts_data LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+--------------------------------+", "| totimestampseconds(ts_data.ts) |", @@ -166,12 +166,12 @@ async fn query_cast_timestamp_nanos_to_others() -> Result<()> { #[tokio::test] async fn query_cast_timestamp_seconds_to_others() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("ts_secs", make_timestamp_table::()?)?; // Original column is seconds, convert to millis and check timestamp let sql = "SELECT to_timestamp_millis(ts) FROM ts_secs LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------------+", "| totimestampmillis(ts_secs.ts) |", @@ -186,7 +186,7 @@ async fn query_cast_timestamp_seconds_to_others() -> Result<()> { // Original column is seconds, convert to micros and check timestamp let sql = "SELECT to_timestamp_micros(ts) FROM ts_secs LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------------+", "| totimestampmicros(ts_secs.ts) |", @@ -200,7 +200,7 @@ async fn query_cast_timestamp_seconds_to_others() -> Result<()> { // to nanos let sql = "SELECT to_timestamp(ts) FROM ts_secs LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------+", "| totimestamp(ts_secs.ts) |", @@ -216,7 +216,7 @@ async fn query_cast_timestamp_seconds_to_others() -> Result<()> { #[tokio::test] async fn query_cast_timestamp_micros_to_others() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table( "ts_micros", make_timestamp_table::()?, @@ -224,7 +224,7 @@ async fn query_cast_timestamp_micros_to_others() -> Result<()> { // Original column is micros, convert to millis and check timestamp let sql = "SELECT to_timestamp_millis(ts) FROM ts_micros LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------------------+", "| totimestampmillis(ts_micros.ts) |", @@ -238,7 +238,7 @@ async fn query_cast_timestamp_micros_to_others() -> Result<()> { // Original column is micros, convert to seconds and check timestamp let sql = "SELECT to_timestamp_seconds(ts) FROM ts_micros LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------------+", "| totimestampseconds(ts_micros.ts) |", @@ -252,7 +252,7 @@ async fn query_cast_timestamp_micros_to_others() -> Result<()> { // Original column is micros, convert to nanos 
and check timestamp let sql = "SELECT to_timestamp(ts) FROM ts_micros LIMIT 3"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------+", "| totimestamp(ts_micros.ts) |", @@ -268,11 +268,11 @@ async fn query_cast_timestamp_micros_to_others() -> Result<()> { #[tokio::test] async fn to_timestamp() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("ts_data", make_timestamp_nano_table()?)?; let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp('2020-09-08T12:00:00+00:00')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", @@ -287,14 +287,14 @@ async fn to_timestamp() -> Result<()> { #[tokio::test] async fn to_timestamp_millis() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table( "ts_data", make_timestamp_table::()?, )?; let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp_millis('2020-09-08T12:00:00+00:00')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", "| COUNT(UInt8(1)) |", @@ -308,14 +308,14 @@ async fn to_timestamp_millis() -> Result<()> { #[tokio::test] async fn to_timestamp_micros() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table( "ts_data", make_timestamp_table::()?, )?; let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp_micros('2020-09-08T12:00:00+00:00')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", @@ -330,11 +330,11 @@ async fn to_timestamp_micros() -> Result<()> { #[tokio::test] async fn to_timestamp_seconds() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("ts_data", make_timestamp_table::()?)?; let sql = "SELECT COUNT(*) FROM ts_data where ts > to_timestamp_seconds('2020-09-08T12:00:00+00:00')"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------------+", @@ -349,11 +349,11 @@ async fn to_timestamp_seconds() -> Result<()> { #[tokio::test] async fn count_distinct_timestamps() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("ts_data", make_timestamp_nano_table()?)?; let sql = "SELECT COUNT(DISTINCT(ts)) FROM ts_data"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------+", @@ -369,7 +369,7 @@ async fn count_distinct_timestamps() -> Result<()> { #[tokio::test] async fn test_current_timestamp_expressions() -> Result<()> { let t1 = chrono::Utc::now().timestamp(); - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let actual = execute(&mut ctx, "SELECT NOW(), NOW() as t2").await; let res1 = actual[0][0].as_str(); let res2 = actual[0][1].as_str(); @@ -387,7 +387,7 @@ async fn test_current_timestamp_expressions() -> Result<()> { #[tokio::test] async fn test_current_timestamp_expressions_non_optimized() -> Result<()> { let t1 = chrono::Utc::now().timestamp(); - let ctx = 
ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT NOW(), NOW() as t2"; let msg = format!("Creating logical plan for '{}'", sql); @@ -397,8 +397,8 @@ async fn test_current_timestamp_expressions_non_optimized() -> Result<()> { let plan = ctx.create_physical_plan(&plan).await.expect(&msg); let msg = format!("Executing physical plan for '{}': {:?}", sql, plan); - let runtime = ctx.state.lock().runtime_env.clone(); - let res = collect(plan, runtime).await.expect(&msg); + let task_ctx = ctx.task_ctx(); + let res = collect(plan, task_ctx).await.expect(&msg); let actual = result_vec(&res); let res1 = actual[0][0].as_str(); @@ -416,7 +416,7 @@ async fn test_current_timestamp_expressions_non_optimized() -> Result<()> { #[tokio::test] async fn timestamp_minmax() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let table_a = make_timestamp_tz_table::(None)?; let table_b = make_timestamp_tz_table::(Some("UTC".to_owned()))?; @@ -424,7 +424,7 @@ async fn timestamp_minmax() -> Result<()> { ctx.register_table("table_b", table_b)?; let sql = "SELECT MIN(table_a.ts), MAX(table_b.ts) FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------+----------------------------+", "| MIN(table_a.ts) | MAX(table_b.ts) |", @@ -440,7 +440,7 @@ async fn timestamp_minmax() -> Result<()> { #[tokio::test] async fn timestamp_coercion() -> Result<()> { { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let table_a = make_timestamp_tz_table::(Some("UTC".to_owned()))?; let table_b = @@ -449,7 +449,7 @@ async fn timestamp_coercion() -> Result<()> { ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------+-------------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -469,14 +469,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------+----------------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -496,14 +496,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------+----------------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -523,14 +523,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = 
ExecutionContext::new(); + let mut ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------+---------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -550,14 +550,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------+----------------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -577,14 +577,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-------------------------+----------------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -604,14 +604,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------+---------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -631,14 +631,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------+-------------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -658,14 +658,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT 
table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------+----------------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -685,14 +685,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------+---------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -712,14 +712,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------+-------------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -739,14 +739,14 @@ async fn timestamp_coercion() -> Result<()> { } { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let table_a = make_timestamp_table::()?; let table_b = make_timestamp_table::()?; ctx.register_table("table_a", table_a)?; ctx.register_table("table_b", table_b)?; let sql = "SELECT table_a.ts, table_b.ts, table_a.ts = table_b.ts FROM table_a, table_b"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------------------------+----------------------------+--------------------------+", "| ts | ts | table_a.ts Eq table_b.ts |", @@ -770,7 +770,7 @@ async fn timestamp_coercion() -> Result<()> { #[tokio::test] async fn group_by_timestamp_millis() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let schema = Arc::new(Schema::new(vec![ Field::new( @@ -802,7 +802,7 @@ async fn group_by_timestamp_millis() -> Result<()> { let sql = "SELECT timestamp, SUM(count) FROM t1 GROUP BY timestamp ORDER BY timestamp ASC"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+---------------------+---------------+", "| timestamp | SUM(t1.count) |", diff --git a/datafusion/tests/sql/udf.rs b/datafusion/tests/sql/udf.rs index 6b714cb368b8c..c35c1d54148ef 100644 --- a/datafusion/tests/sql/udf.rs +++ b/datafusion/tests/sql/udf.rs @@ -51,7 +51,7 @@ async fn scalar_udf() -> Result<()> { ], )?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch]])?; ctx.register_table("t", Arc::new(provider))?; @@ -97,8 +97,8 @@ async fn scalar_udf() -> Result<()> { let plan = ctx.optimize(&plan)?; let plan = ctx.create_physical_plan(&plan).await?; - let 
runtime = ctx.state.lock().runtime_env.clone(); - let result = collect(plan, runtime).await?; + let task_ctx = ctx.task_ctx(); + let result = collect(plan, task_ctx).await?; let expected = vec![ "+-----+-----+-----------------+", @@ -155,7 +155,7 @@ async fn simple_udaf() -> Result<()> { vec![Arc::new(Int32Array::from_slice(&[4, 5]))], )?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let provider = MemTable::try_new(Arc::new(schema), vec![vec![batch1], vec![batch2]])?; ctx.register_table("t", Arc::new(provider))?; diff --git a/datafusion/tests/sql/unicode.rs b/datafusion/tests/sql/unicode.rs index 55747f2a9ac4e..f9cb4c482989c 100644 --- a/datafusion/tests/sql/unicode.rs +++ b/datafusion/tests/sql/unicode.rs @@ -116,7 +116,7 @@ async fn generic_query_length>>( let table = MemTable::try_new(schema, vec![vec![data]])?; - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_table("test", Arc::new(table))?; let sql = "SELECT length(c1) FROM test"; let actual = execute(&mut ctx, sql).await; diff --git a/datafusion/tests/sql/union.rs b/datafusion/tests/sql/union.rs index a1f81d24f4566..958e4b9d9389f 100644 --- a/datafusion/tests/sql/union.rs +++ b/datafusion/tests/sql/union.rs @@ -19,9 +19,9 @@ use super::*; #[tokio::test] async fn union_all() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT 1 as x UNION ALL SELECT 2 as x"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["+---+", "| x |", "+---+", "| 1 |", "| 2 |", "+---+"]; assert_batches_eq!(expected, &actual); Ok(()) @@ -29,7 +29,7 @@ async fn union_all() -> Result<()> { #[tokio::test] async fn csv_union_all() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "SELECT c1 FROM aggregate_test_100 UNION ALL SELECT c1 FROM aggregate_test_100"; @@ -40,9 +40,9 @@ async fn csv_union_all() -> Result<()> { #[tokio::test] async fn union_distinct() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT 1 as x UNION SELECT 1 as x"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec!["+---+", "| x |", "+---+", "| 1 |", "+---+"]; assert_batches_eq!(expected, &actual); Ok(()) @@ -50,10 +50,10 @@ async fn union_distinct() -> Result<()> { #[tokio::test] async fn union_all_with_aggregate() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let ctx = SessionContext::new(); let sql = "SELECT SUM(d) FROM (SELECT 1 as c, 2 as d UNION ALL SELECT 1 as c, 3 AS d) as a"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+----------+", "| SUM(a.d) |", diff --git a/datafusion/tests/sql/window.rs b/datafusion/tests/sql/window.rs index 321ab320f5be7..1b7335b122333 100644 --- a/datafusion/tests/sql/window.rs +++ b/datafusion/tests/sql/window.rs @@ -20,7 +20,7 @@ use super::*; /// for window functions without order by the first, last, and nth function call does not make sense #[tokio::test] async fn csv_query_window_with_empty_over() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "select \ c9, \ @@ -30,7 +30,7 @@ async fn csv_query_window_with_empty_over() 
-> Result<()> { from aggregate_test_100 \ order by c9 \ limit 5"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------+------------------------------+----------------------------+----------------------------+", "| c9 | COUNT(aggregate_test_100.c5) | MAX(aggregate_test_100.c5) | MIN(aggregate_test_100.c5) |", @@ -49,7 +49,7 @@ async fn csv_query_window_with_empty_over() -> Result<()> { /// for window functions without order by the first, last, and nth function call does not make sense #[tokio::test] async fn csv_query_window_with_partition_by() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "select \ c9, \ @@ -61,7 +61,7 @@ async fn csv_query_window_with_partition_by() -> Result<()> { from aggregate_test_100 \ order by c9 \ limit 5"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------+-------------------------------------------+-------------------------------------------+---------------------------------------------+-------------------------------------------+-------------------------------------------+", "| c9 | SUM(CAST(aggregate_test_100.c4 AS Int32)) | AVG(CAST(aggregate_test_100.c4 AS Int32)) | COUNT(CAST(aggregate_test_100.c4 AS Int32)) | MAX(CAST(aggregate_test_100.c4 AS Int32)) | MIN(CAST(aggregate_test_100.c4 AS Int32)) |", @@ -79,7 +79,7 @@ async fn csv_query_window_with_partition_by() -> Result<()> { #[tokio::test] async fn csv_query_window_with_order_by() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "select \ c9, \ @@ -94,7 +94,7 @@ async fn csv_query_window_with_order_by() -> Result<()> { from aggregate_test_100 \ order by c9 \ limit 5"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+------------------------------------+-----------------------------------+-------------------------------------------+", "| c9 | SUM(aggregate_test_100.c5) | AVG(aggregate_test_100.c5) | COUNT(aggregate_test_100.c5) | MAX(aggregate_test_100.c5) | MIN(aggregate_test_100.c5) | FIRST_VALUE(aggregate_test_100.c5) | LAST_VALUE(aggregate_test_100.c5) | NTH_VALUE(aggregate_test_100.c5,Int64(2)) |", @@ -112,7 +112,7 @@ async fn csv_query_window_with_order_by() -> Result<()> { #[tokio::test] async fn csv_query_window_with_partition_by_order_by() -> Result<()> { - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx).await?; let sql = "select \ c9, \ @@ -127,7 +127,7 @@ async fn csv_query_window_with_partition_by_order_by() -> Result<()> { from aggregate_test_100 \ order by c9 \ limit 5"; - let actual = execute_to_batches(&mut ctx, sql).await; + let actual = execute_to_batches(&ctx, sql).await; let expected = vec![ "+-----------+----------------------------+----------------------------+------------------------------+----------------------------+----------------------------+------------------------------------+-----------------------------------+-------------------------------------------+", "| c9 | SUM(aggregate_test_100.c5) | 
AVG(aggregate_test_100.c5) | COUNT(aggregate_test_100.c5) | MAX(aggregate_test_100.c5) | MIN(aggregate_test_100.c5) | FIRST_VALUE(aggregate_test_100.c5) | LAST_VALUE(aggregate_test_100.c5) | NTH_VALUE(aggregate_test_100.c5,Int64(2)) |", diff --git a/datafusion/tests/statistics.rs b/datafusion/tests/statistics.rs index c5fba894e6860..76b185e2c2a31 100644 --- a/datafusion/tests/statistics.rs +++ b/datafusion/tests/statistics.rs @@ -29,12 +29,12 @@ use datafusion::{ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }, - prelude::ExecutionContext, + prelude::SessionContext, scalar::ScalarValue, }; use async_trait::async_trait; -use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::context::TaskContext; /// This is a testing structure for statistics /// It will act both as a table provider and execution plan @@ -144,7 +144,7 @@ impl ExecutionPlan for StatisticsValidation { async fn execute( &self, _partition: usize, - _runtime: Arc, + _context: Arc, ) -> Result { unimplemented!("This plan only serves for testing statistics") } @@ -171,8 +171,8 @@ impl ExecutionPlan for StatisticsValidation { } } -fn init_ctx(stats: Statistics, schema: Schema) -> Result { - let mut ctx = ExecutionContext::new(); +fn init_ctx(stats: Statistics, schema: Schema) -> Result { + let mut ctx = SessionContext::new(); let provider: Arc = Arc::new(StatisticsValidation::new(stats, Arc::new(schema))); ctx.register_table("stats_table", provider)?; diff --git a/datafusion/tests/user_defined_plan.rs b/datafusion/tests/user_defined_plan.rs index 17578047378a7..37c47969fbca9 100644 --- a/datafusion/tests/user_defined_plan.rs +++ b/datafusion/tests/user_defined_plan.rs @@ -69,8 +69,8 @@ use arrow::{ }; use datafusion::{ error::{DataFusionError, Result}, - execution::context::ExecutionContextState, execution::context::QueryPlanner, + execution::context::SessionState, logical_plan::{Expr, LogicalPlan, UserDefinedLogicalNode}, optimizer::{optimizer::OptimizerRule, utils::optimize_children}, physical_plan::{ @@ -79,21 +79,20 @@ use datafusion::{ DisplayFormatType, Distribution, ExecutionPlan, Partitioning, PhysicalPlanner, RecordBatchStream, SendableRecordBatchStream, Statistics, }, - prelude::{ExecutionConfig, ExecutionContext}, + prelude::{SessionConfig, SessionContext}, }; use fmt::Debug; use std::task::{Context, Poll}; use std::{any::Any, collections::BTreeMap, fmt, sync::Arc}; use async_trait::async_trait; -use datafusion::execution::context::ExecutionProps; -use datafusion::execution::runtime_env::RuntimeEnv; +use datafusion::execution::context::{ExecutionProps, TaskContext}; use datafusion::logical_plan::plan::{Extension, Sort}; use datafusion::logical_plan::{DFSchemaRef, Limit}; /// Execute the specified sql and return the resulting record batches /// pretty printed as a String. -async fn exec_sql(ctx: &mut ExecutionContext, sql: &str) -> Result { +async fn exec_sql(ctx: &mut SessionContext, sql: &str) -> Result { let df = ctx.sql(sql).await?; let batches = df.collect().await?; pretty_format_batches(&batches) @@ -102,7 +101,7 @@ async fn exec_sql(ctx: &mut ExecutionContext, sql: &str) -> Result { } /// Create a test table. 
-async fn setup_table(mut ctx: ExecutionContext) -> Result { +async fn setup_table(mut ctx: SessionContext) -> Result { let sql = "CREATE EXTERNAL TABLE sales(customer_id VARCHAR, revenue BIGINT) STORED AS CSV location 'tests/customer.csv'"; let expected = vec!["++", "++"]; @@ -114,9 +113,7 @@ async fn setup_table(mut ctx: ExecutionContext) -> Result { Ok(ctx) } -async fn setup_table_without_schemas( - mut ctx: ExecutionContext, -) -> Result { +async fn setup_table_without_schemas(mut ctx: SessionContext) -> Result { let sql = "CREATE EXTERNAL TABLE sales STORED AS CSV location 'tests/customer.csv'"; let expected = vec!["++", "++"]; @@ -135,10 +132,7 @@ const QUERY: &str = // Run the query using the specified execution context and compare it // to the known result -async fn run_and_compare_query( - mut ctx: ExecutionContext, - description: &str, -) -> Result<()> { +async fn run_and_compare_query(mut ctx: SessionContext, description: &str) -> Result<()> { let expected = vec![ "+-------------+---------+", "| customer_id | revenue |", @@ -166,7 +160,7 @@ async fn run_and_compare_query( // Run the query using the specified execution context and compare it // to the known result async fn run_and_compare_query_with_auto_schemas( - mut ctx: ExecutionContext, + mut ctx: SessionContext, description: &str, ) -> Result<()> { let expected = vec![ @@ -196,14 +190,14 @@ async fn run_and_compare_query_with_auto_schemas( #[tokio::test] // Run the query using default planners and optimizer async fn normal_query_without_schemas() -> Result<()> { - let ctx = setup_table_without_schemas(ExecutionContext::new()).await?; + let ctx = setup_table_without_schemas(SessionContext::new()).await?; run_and_compare_query_with_auto_schemas(ctx, "Default context").await } #[tokio::test] // Run the query using default planners and optimizer async fn normal_query() -> Result<()> { - let ctx = setup_table(ExecutionContext::new()).await?; + let ctx = setup_table(SessionContext::new()).await?; run_and_compare_query(ctx, "Default context").await } @@ -247,13 +241,13 @@ async fn topk_plan() -> Result<()> { Ok(()) } -fn make_topk_context() -> ExecutionContext { - let config = ExecutionConfig::new() +fn make_topk_context() -> SessionContext { + let config = SessionConfig::new() .with_query_planner(Arc::new(TopKQueryPlanner {})) .with_target_partitions(48) .add_optimizer_rule(Arc::new(TopKOptimizerRule {})); - ExecutionContext::with_config(config) + SessionContext::with_config(config) } // ------ The implementation of the TopK code follows ----- @@ -267,7 +261,7 @@ impl QueryPlanner for TopKQueryPlanner { async fn create_physical_plan( &self, logical_plan: &LogicalPlan, - ctx_state: &ExecutionContextState, + session_state: &SessionState, ) -> Result> { // Teach the default physical planner how to plan TopK nodes. 
let physical_planner = @@ -276,7 +270,7 @@ impl QueryPlanner for TopKQueryPlanner { )]); // Delegate most work of physical planning to the default physical planner physical_planner - .create_physical_plan(logical_plan, ctx_state) + .create_physical_plan(logical_plan, session_state) .await } } @@ -386,7 +380,7 @@ impl ExtensionPlanner for TopKPlanner { node: &dyn UserDefinedLogicalNode, logical_inputs: &[&LogicalPlan], physical_inputs: &[Arc], - _ctx_state: &ExecutionContextState, + _session_state: &SessionState, ) -> Result>> { Ok( if let Some(topk_node) = node.as_any().downcast_ref::() { @@ -468,7 +462,7 @@ impl ExecutionPlan for TopKExec { async fn execute( &self, partition: usize, - runtime: Arc, + context: Arc, ) -> Result { if 0 != partition { return Err(DataFusionError::Internal(format!( @@ -478,7 +472,7 @@ impl ExecutionPlan for TopKExec { } Ok(Box::pin(TopKReader { - input: self.input.execute(partition, runtime).await?, + input: self.input.execute(partition, context).await?, k: self.k, done: false, state: BTreeMap::new(), diff --git a/docs/source/python/api/execution_context.rst b/docs/source/python/api/execution_context.rst index 7f8c840ca0ad9..5b7e0f82f996f 100644 --- a/docs/source/python/api/execution_context.rst +++ b/docs/source/python/api/execution_context.rst @@ -18,10 +18,10 @@ .. _api.execution_context: .. currentmodule:: datafusion -ExecutionContext +SessionContext ================ .. autosummary:: :toctree: ../generated/ - ExecutionContext + SessionContext diff --git a/docs/source/python/index.rst b/docs/source/python/index.rst index 97221868d4b9f..3cafc550b78a2 100644 --- a/docs/source/python/index.rst +++ b/docs/source/python/index.rst @@ -44,7 +44,7 @@ Simple usage: import pyarrow # create a context - ctx = datafusion.ExecutionContext() + ctx = datafusion.SessionContext() # create a RecordBatch and a new DataFrame from it batch = pyarrow.RecordBatch.from_arrays( diff --git a/docs/source/user-guide/distributed/clients/rust.md b/docs/source/user-guide/distributed/clients/rust.md index ccf19aa70e3cc..bdb8d2c412b30 100644 --- a/docs/source/user-guide/distributed/clients/rust.md +++ b/docs/source/user-guide/distributed/clients/rust.md @@ -20,7 +20,7 @@ # Ballista Rust Client Ballista usage is very similar to DataFusion. Tha main difference is that the starting point is a `BallistaContext` -instead of the DataFusion `ExecutionContext`. Ballista uses the same DataFrame API as DataFusion. +instead of the DataFusion `SessionContext`. Ballista uses the same DataFrame API as DataFusion. The following code sample demonstrates how to create a `BallistaContext` to connect to a Ballista scheduler process. diff --git a/docs/source/user-guide/example-usage.md b/docs/source/user-guide/example-usage.md index 77930260e038f..4f2a6cf1a01f4 100644 --- a/docs/source/user-guide/example-usage.md +++ b/docs/source/user-guide/example-usage.md @@ -36,7 +36,7 @@ use datafusion::prelude::*; #[tokio::main] async fn main() -> datafusion::error::Result<()> { // register the table - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new()).await?; // create a plan to run a SQL query @@ -56,7 +56,7 @@ use datafusion::prelude::*; #[tokio::main] async fn main() -> datafusion::error::Result<()> { // create the dataframe - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?; let df = df.filter(col("a").lt_eq(col("b")))? 
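Note for reviewers: taken together, the datafusion/tests/sql/udf.rs and docs/source/user-guide/example-usage.md hunks above imply the following end-to-end usage after this change. This is an illustrative sketch assembled only from calls that appear in this diff (plus `create_logical_plan`, which is assumed to be unchanged by the patch); the table name, file path, and query are placeholders, and the snippet is not itself part of the patch.

```rust
use datafusion::physical_plan::collect;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    // SessionContext replaces ExecutionContext as the entry point
    let mut ctx = SessionContext::new();
    ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new())
        .await?;

    // Logical planning and optimization are unchanged by this patch
    let plan = ctx.create_logical_plan("SELECT a, b FROM example LIMIT 5")?;
    let plan = ctx.optimize(&plan)?;
    let plan = ctx.create_physical_plan(&plan).await?;

    // Execution now takes a TaskContext derived from the session
    // instead of the session's RuntimeEnv
    let task_ctx = ctx.task_ctx();
    let batches = collect(plan, task_ctx).await?;
    println!("{} record batches", batches.len());
    Ok(())
}
```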
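The datafusion/tests/user_defined_plan.rs hunk above likewise shows the replacement for `ExecutionConfig`. A minimal sketch of that builder pattern, with an arbitrary partition count and without the custom query planner or optimizer rule used in that test:

```rust
use datafusion::prelude::{SessionConfig, SessionContext};

fn make_context() -> SessionContext {
    // SessionConfig replaces ExecutionConfig; the builder methods keep their names
    let config = SessionConfig::new().with_target_partitions(8);
    SessionContext::with_config(config)
}
```

Callers that previously passed an `ExecutionConfig` to `ExecutionContext::with_config` should only need the two renames shown here.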
diff --git a/docs/source/user-guide/library.md b/docs/source/user-guide/library.md index f0be42c972f00..5f4224816455d 100644 --- a/docs/source/user-guide/library.md +++ b/docs/source/user-guide/library.md @@ -57,7 +57,7 @@ use datafusion::prelude::*; #[tokio::main] async fn main() -> datafusion::error::Result<()> { // register the table - let mut ctx = ExecutionContext::new(); + let mut ctx = SessionContext::new(); ctx.register_csv("test", "", CsvReadOptions::new()).await?; // create a plan to run a SQL query