diff --git a/Cargo.toml b/Cargo.toml index 2a3e66323..c02f4bf68 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,15 +21,15 @@ members = ["ballista-cli", "ballista/client", "ballista/core", "ballista/executo resolver = "2" [workspace.dependencies] -arrow = { version = "53", features = ["ipc_compression"] } -arrow-flight = { version = "53", features = ["flight-sql-experimental"] } +arrow = { version = "54", features = ["ipc_compression"] } +arrow-flight = { version = "54", features = ["flight-sql-experimental"] } clap = { version = "4.5", features = ["derive", "cargo"] } configure_me = { version = "0.4.0" } configure_me_codegen = { version = "0.4.4" } -datafusion = "44.0.0" -datafusion-cli = "44.0.0" -datafusion-proto = "44.0.0" -datafusion-proto-common = "44.0.0" +datafusion = "45.0.0" +datafusion-cli = "45.0.0" +datafusion-proto = "45.0.0" +datafusion-proto-common = "45.0.0" object_store = "0.11" prost = "0.13" prost-types = "0.13" @@ -45,15 +45,15 @@ ctor = { version = "0.2" } mimalloc = { version = "0.1" } tokio = { version = "1" } -uuid = { version = "1.10", features = ["v4", "v7"] } -rand = { version = "0.8" } +uuid = { version = "1.13", features = ["v4", "v7"] } +rand = { version = "0.9" } env_logger = { version = "0.11" } futures = { version = "0.3" } log = { version = "0.4" } parking_lot = { version = "0.12" } -tempfile = { version = "3" } +tempfile = { version = "3.16" } dashmap = { version = "6.1" } -async-trait = { version = "0.1.4" } +async-trait = { version = "0.1" } serde = { version = "1.0" } tokio-stream = { version = "0.1" } url = { version = "2.5" } diff --git a/ballista-cli/Cargo.toml b/ballista-cli/Cargo.toml index 46587d187..7061e8216 100644 --- a/ballista-cli/Cargo.toml +++ b/ballista-cli/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "ballista-cli" description = "Command Line Client for Ballista distributed query engine." -version = "44.0.0" +version = "45.0.0" authors = ["Apache DataFusion "] edition = "2021" keywords = ["ballista", "cli"] @@ -28,14 +28,14 @@ repository = "https://github.com/apache/arrow-ballista" readme = "README.md" [dependencies] -ballista = { path = "../ballista/client", version = "44.0.0", features = ["standalone"] } +ballista = { path = "../ballista/client", version = "45.0.0", features = ["standalone"] } clap = { workspace = true, features = ["derive", "cargo"] } datafusion = { workspace = true } datafusion-cli = { workspace = true } -dirs = "5.0.1" +dirs = "6.0" env_logger = { workspace = true } mimalloc = { workspace = true } -rustyline = "14.0.0" +rustyline = "15.0.0" tokio = { workspace = true, features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] } [features] diff --git a/ballista/client/Cargo.toml b/ballista/client/Cargo.toml index a509a30f1..106b7b672 100644 --- a/ballista/client/Cargo.toml +++ b/ballista/client/Cargo.toml @@ -19,7 +19,7 @@ name = "ballista" description = "Ballista Distributed Compute" license = "Apache-2.0" -version = "44.0.0" +version = "45.0.0" homepage = "https://github.com/apache/arrow-ballista" repository = "https://github.com/apache/arrow-ballista" readme = "README.md" @@ -28,9 +28,9 @@ edition = "2021" [dependencies] async-trait = { workspace = true } -ballista-core = { path = "../core", version = "44.0.0" } -ballista-executor = { path = "../executor", version = "44.0.0", optional = true } -ballista-scheduler = { path = "../scheduler", version = "44.0.0", optional = true } +ballista-core = { path = "../core", version = "45.0.0" } +ballista-executor = { path = "../executor", version = "45.0.0", optional = true } +ballista-scheduler = { path = "../scheduler", version = "45.0.0", optional = true } datafusion = { workspace = true } log = { workspace = true } @@ -38,8 +38,8 @@ tokio = { workspace = true } url = { workspace = true } [dev-dependencies] -ballista-executor = { path = "../executor", version = "44.0.0" } -ballista-scheduler = { path = "../scheduler", version = "44.0.0" } +ballista-executor = { path = "../executor", version = "45.0.0" } +ballista-scheduler = { path = "../scheduler", version = "45.0.0" } ctor = { workspace = true } datafusion-proto = { workspace = true } env_logger = { workspace = true } diff --git a/ballista/client/tests/context_checks.rs b/ballista/client/tests/context_checks.rs index a46f93803..2601293de 100644 --- a/ballista/client/tests/context_checks.rs +++ b/ballista/client/tests/context_checks.rs @@ -365,4 +365,103 @@ mod supported { Ok(()) } + + /// looks like `ctx.enable_url_table()` changes session context id. + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + async fn should_execute_sql_show_with_url_table( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) { + let ctx = ctx.enable_url_table(); + + let result = ctx + .sql(&format!("select string_col, timestamp_col from '{test_data}/alltypes_plain.parquet' where id > 4")) + .await + .unwrap() + .collect() + .await + .unwrap(); + + let expected = [ + "+------------+---------------------+", + "| string_col | timestamp_col |", + "+------------+---------------------+", + "| 31 | 2009-03-01T00:01:00 |", + "| 30 | 2009-04-01T00:00:00 |", + "| 31 | 2009-04-01T00:01:00 |", + "+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + } + + #[rstest] + #[case::standalone(standalone_context())] + #[case::remote(remote_context())] + #[tokio::test] + #[cfg(not(windows))] // test is failing at windows, can't debug it + async fn should_support_sql_insert_into( + #[future(awt)] + #[case] + ctx: SessionContext, + test_data: String, + ) { + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await + .unwrap(); + + let write_dir = tempfile::tempdir().expect("temporary directory to be created"); + let write_dir_path = write_dir + .path() + .to_str() + .expect("path to be converted to str"); + + ctx.sql("select * from test") + .await + .unwrap() + .write_parquet(write_dir_path, Default::default(), Default::default()) + .await + .unwrap(); + + ctx.register_parquet("written_table", write_dir_path, Default::default()) + .await + .unwrap(); + + ctx.sql("INSERT INTO written_table select * from test") + .await + .unwrap() + .show() + .await + .unwrap(); + + let result = ctx + .sql("select id, string_col, timestamp_col from written_table where id > 4 order by id") + .await.unwrap() + .collect() + .await.unwrap(); + + let expected = [ + "+----+------------+---------------------+", + "| id | string_col | timestamp_col |", + "+----+------------+---------------------+", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "+----+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + } } diff --git a/ballista/client/tests/context_unsupported.rs b/ballista/client/tests/context_unsupported.rs index 347071e55..aa9827993 100644 --- a/ballista/client/tests/context_unsupported.rs +++ b/ballista/client/tests/context_unsupported.rs @@ -144,112 +144,6 @@ mod unsupported { "+----+----------+---------------------+", ]; - assert_batches_eq!(expected, &result); - } - #[rstest] - #[case::standalone(standalone_context())] - #[case::remote(remote_context())] - #[tokio::test] - #[should_panic] - // "Error: Internal(failed to serialize logical plan: Internal(LogicalPlan serde is not yet implemented for Dml))" - async fn should_support_sql_insert_into( - #[future(awt)] - #[case] - ctx: SessionContext, - test_data: String, - ) { - ctx.register_parquet( - "test", - &format!("{test_data}/alltypes_plain.parquet"), - Default::default(), - ) - .await - .unwrap(); - let write_dir = tempfile::tempdir().expect("temporary directory to be created"); - let write_dir_path = write_dir - .path() - .to_str() - .expect("path to be converted to str"); - - ctx.sql("select * from test") - .await - .unwrap() - .write_parquet(write_dir_path, Default::default(), Default::default()) - .await - .unwrap(); - - ctx.register_parquet("written_table", write_dir_path, Default::default()) - .await - .unwrap(); - - let _ = ctx - .sql("INSERT INTO written_table select * from written_table") - .await - .unwrap() - .collect() - .await - .unwrap(); - - let result = ctx - .sql("select id, string_col, timestamp_col from written_table where id > 4 order by id") - .await.unwrap() - .collect() - .await.unwrap(); - - let expected = [ - "+----+------------+---------------------+", - "| id | string_col | timestamp_col |", - "+----+------------+---------------------+", - "| 5 | 31 | 2009-03-01T00:01:00 |", - "| 5 | 31 | 2009-03-01T00:01:00 |", - "| 6 | 30 | 2009-04-01T00:00:00 |", - "| 6 | 30 | 2009-04-01T00:00:00 |", - "| 7 | 31 | 2009-04-01T00:01:00 |", - "| 7 | 31 | 2009-04-01T00:01:00 |", - "+----+------------+---------------------+", - ]; - - assert_batches_eq!(expected, &result); - } - - /// looks like `ctx.enable_url_table()` changes session context id. - /// - /// Error returned: - /// ``` - /// Failed to load SessionContext for session ID b5530099-63d1-43b1-9e11-87ac83bb33e5: - /// General error: No session for b5530099-63d1-43b1-9e11-87ac83bb33e5 found - /// ``` - #[rstest] - #[case::standalone(standalone_context())] - #[case::remote(remote_context())] - #[tokio::test] - #[should_panic] - async fn should_execute_sql_show_with_url_table( - #[future(awt)] - #[case] - ctx: SessionContext, - test_data: String, - ) { - let ctx = ctx.enable_url_table(); - - let result = ctx - .sql(&format!("select string_col, timestamp_col from '{test_data}/alltypes_plain.parquet' where id > 4")) - .await - .unwrap() - .collect() - .await - .unwrap(); - - let expected = [ - "+------------+---------------------+", - "| string_col | timestamp_col |", - "+------------+---------------------+", - "| 31 | 2009-03-01T00:01:00 |", - "| 30 | 2009-04-01T00:00:00 |", - "| 31 | 2009-04-01T00:01:00 |", - "+------------+---------------------+", - ]; - assert_batches_eq!(expected, &result); } } diff --git a/ballista/core/Cargo.toml b/ballista/core/Cargo.toml index 731b32c0e..4f67525a5 100644 --- a/ballista/core/Cargo.toml +++ b/ballista/core/Cargo.toml @@ -19,7 +19,7 @@ name = "ballista-core" description = "Ballista Distributed Compute" license = "Apache-2.0" -version = "44.0.0" +version = "45.0.0" homepage = "https://github.com/apache/arrow-ballista" repository = "https://github.com/apache/arrow-ballista" readme = "README.md" diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs index cb7f7c5d7..74256489a 100644 --- a/ballista/core/src/config.rs +++ b/ballista/core/src/config.rs @@ -32,6 +32,13 @@ pub const BALLISTA_STANDALONE_PARALLELISM: &str = "ballista.standalone.paralleli /// max message size for gRPC clients pub const BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE: &str = "ballista.grpc_client_max_message_size"; +/// enable or disable ballista dml planner extension. +/// when enabled planner will use custom logical planner DML +/// extension which will serialize table provider used in DML +/// +/// this configuration should be disabled if using remote schema +/// registries. +pub const BALLISTA_PLANNER_DML_EXTENSION: &str = "ballista.planner.dml_extension"; pub type ParseResult = result::Result; use std::sync::LazyLock; @@ -48,6 +55,10 @@ static CONFIG_ENTRIES: LazyLock> = LazyLock::new(|| "Configuration for max message size in gRPC clients".to_string(), DataType::UInt64, Some((16 * 1024 * 1024).to_string())), + ConfigEntry::new(BALLISTA_PLANNER_DML_EXTENSION.to_string(), + "Enable ballista planner DML extension".to_string(), + DataType::Boolean, + Some((true).to_string())), ]; entries .into_iter() @@ -165,6 +176,10 @@ impl BallistaConfig { self.get_usize_setting(BALLISTA_STANDALONE_PARALLELISM) } + pub fn planner_dml_extension(&self) -> bool { + self.get_bool_setting(BALLISTA_PLANNER_DML_EXTENSION) + } + fn get_usize_setting(&self, key: &str) -> usize { if let Some(v) = self.settings.get(key) { // infallible because we validate all configs in the constructor diff --git a/ballista/core/src/execution_plans/shuffle_reader.rs b/ballista/core/src/execution_plans/shuffle_reader.rs index 7a20f1215..a9521dc88 100644 --- a/ballista/core/src/execution_plans/shuffle_reader.rs +++ b/ballista/core/src/execution_plans/shuffle_reader.rs @@ -50,7 +50,7 @@ use datafusion::physical_plan::stream::RecordBatchStreamAdapter; use itertools::Itertools; use log::{error, info}; use rand::prelude::SliceRandom; -use rand::thread_rng; +use rand::rng; use tokio::sync::{mpsc, Semaphore}; use tokio_stream::wrappers::ReceiverStream; @@ -163,7 +163,7 @@ impl ExecutionPlan for ShuffleReaderExec { .map(|(_, p)| p) .collect(); // Shuffle partitions for evenly send fetching partition requests to avoid hot executors within multiple tasks - partition_locations.shuffle(&mut thread_rng()); + partition_locations.shuffle(&mut rng()); let response_receiver = send_fetch_partitions(partition_locations, max_request_num); diff --git a/ballista/core/src/planner.rs b/ballista/core/src/planner.rs index 266da3c6e..77690cd44 100644 --- a/ballista/core/src/planner.rs +++ b/ballista/core/src/planner.rs @@ -17,14 +17,16 @@ use crate::config::BallistaConfig; use crate::execution_plans::DistributedQueryExec; -use crate::serde::BallistaLogicalExtensionCodec; +use crate::serde::{BallistaDmlExtension, BallistaLogicalExtensionCodec}; use async_trait::async_trait; use datafusion::arrow::datatypes::Schema; +use datafusion::common::plan_err; use datafusion::common::tree_node::{TreeNode, TreeNodeVisitor}; +use datafusion::datasource::DefaultTableSource; use datafusion::error::DataFusionError; use datafusion::execution::context::{QueryPlanner, SessionState}; -use datafusion::logical_expr::{LogicalPlan, TableScan}; +use datafusion::logical_expr::{DmlStatement, Extension, LogicalPlan, TableScan}; use datafusion::physical_plan::empty::EmptyExec; use datafusion::physical_plan::ExecutionPlan; use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner}; @@ -125,6 +127,41 @@ impl QueryPlanner for BallistaQueryPlanner { log::debug!("create_physical_plan - handling empty exec"); Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty())))) } + // At the moment DML statement uses TableReference instead of TableProvider. + // As ballista has two contexts (client and scheduler) scheduler context may not + // know table provider for given table reference, thus we need to attach + // table provider to this DML statement. + LogicalPlan::Dml(DmlStatement { table_name, .. }) + if self.config.planner_dml_extension() => + { + let table_name = table_name.to_owned(); + let table = table_name.table().to_string(); + let schema = session_state.schema_for_ref(table_name.clone())?; + let table_provider = match schema.table(&table).await? { + Some(ref provider) => Ok(Arc::clone(provider)), + _ => plan_err!("No table named '{table}'"), + }?; + + let table_source = Arc::new(DefaultTableSource::new(table_provider)); + let table = + TableScan::try_new(table_name, table_source, None, vec![], None)?; + + // custom made logical extension node is used to attach table reference + let node = Arc::new(BallistaDmlExtension { + dml: logical_plan.clone(), + table, + }); + let plan = LogicalPlan::Extension(Extension { node }); + log::debug!("create_physical_plan - handling DML statement"); + + Ok(Arc::new(DistributedQueryExec::::with_extension( + self.scheduler_url.clone(), + self.config.clone(), + plan.clone(), + self.extension_codec.clone(), + session_state.session_id().to_string(), + ))) + } _ => { log::debug!("create_physical_plan - handling general statement"); diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs index 84cf80684..95bd6084e 100644 --- a/ballista/core/src/serde/mod.rs +++ b/ballista/core/src/serde/mod.rs @@ -22,8 +22,11 @@ use crate::{error::BallistaError, serde::scheduler::Action as BallistaAction}; use arrow_flight::sql::ProstMessageExt; use datafusion::arrow::datatypes::SchemaRef; -use datafusion::common::{DataFusionError, Result}; +use datafusion::common::{plan_err, DataFusionError, Result}; use datafusion::execution::FunctionRegistry; +use datafusion::logical_expr::{ + Extension, LogicalPlan, TableScan, UserDefinedLogicalNodeCore, +}; use datafusion::physical_plan::{ExecutionPlan, Partitioning}; use datafusion_proto::logical_plan::file_formats::{ ArrowLogicalExtensionCodec, AvroLogicalExtensionCodec, CsvLogicalExtensionCodec, @@ -179,7 +182,31 @@ impl LogicalExtensionCodec for BallistaLogicalExtensionCodec { inputs: &[datafusion::logical_expr::LogicalPlan], ctx: &datafusion::prelude::SessionContext, ) -> Result { - self.default_codec.try_decode(buf, inputs, ctx) + match BallistaExtensionProto::decode(buf) { + Ok(extension) => match extension.extension { + Some(BallistaExtensionType::Dml(BallistaDmlExtensionProto { + dml: Some(dml), + table: Some(table), + })) => { + let table = table.try_into_logical_plan(ctx, self)?; + match table { + LogicalPlan::TableScan(scan) => { + let dml = dml.try_into_logical_plan(ctx, self)?; + Ok(Extension { + node: Arc::new(BallistaDmlExtension { dml, table: scan }), + }) + } + _ => plan_err!( + "TableScan expected in ballista DML extension definition" + ), + } + } + None => plan_err!("Ballista extension can't be None"), + _ => plan_err!("Ballista extension not supported"), + }, + + Err(_e) => self.default_codec.try_decode(buf, inputs, ctx), + } } fn try_encode( @@ -187,7 +214,32 @@ impl LogicalExtensionCodec for BallistaLogicalExtensionCodec { node: &datafusion::logical_expr::Extension, buf: &mut Vec, ) -> Result<()> { - self.default_codec.try_encode(node, buf) + if let Some(BallistaDmlExtension { dml: input, table }) = + node.node.as_any().downcast_ref::() + { + let input = LogicalPlanNode::try_from_logical_plan(input, self)?; + + let table = LogicalPlanNode::try_from_logical_plan( + &LogicalPlan::TableScan(table.clone()), + self, + )?; + let extension = BallistaDmlExtensionProto { + dml: Some(input), + table: Some(table), + }; + + let extension = BallistaExtensionProto { + extension: Some(BallistaExtensionType::Dml(extension)), + }; + + extension + .encode(buf) + .map_err(|e| DataFusionError::Execution(e.to_string()))?; + + Ok(()) + } else { + self.default_codec.try_encode(node, buf) + } } fn try_decode_table_provider( @@ -487,6 +539,74 @@ struct FileFormatProto { pub blob: Vec, } +#[derive(Clone, PartialEq, prost::Message)] +struct BallistaExtensionProto { + #[prost(oneof = "BallistaExtensionType", tags = "1")] + extension: Option, +} + +#[derive(Clone, PartialEq, ::prost::Oneof)] +enum BallistaExtensionType { + #[prost(message, tag = "1")] + Dml(BallistaDmlExtensionProto), +} + +#[derive(Clone, PartialEq, prost::Message)] +struct BallistaDmlExtensionProto { + #[prost(message, tag = 1)] + pub dml: Option, + #[prost(message, tag = 2)] + pub table: Option, +} + +#[derive(Debug, Hash, PartialEq, Eq, Clone)] +pub struct BallistaDmlExtension { + /// LogicalPlan::DML + /// DMLStatement is expected + pub dml: LogicalPlan, + /// Table provider which is referenced + /// from LogicalPlan::DML + pub table: TableScan, +} + +impl std::cmp::PartialOrd for BallistaDmlExtension { + fn partial_cmp(&self, other: &Self) -> Option { + self.dml.partial_cmp(&other.dml) + } +} +impl UserDefinedLogicalNodeCore for BallistaDmlExtension { + fn name(&self) -> &str { + "BallistaDmlExtension" + } + + fn inputs(&self) -> Vec<&datafusion::logical_expr::LogicalPlan> { + vec![&self.dml] + } + + fn schema(&self) -> &datafusion::common::DFSchemaRef { + self.dml.schema() + } + + fn expressions(&self) -> Vec { + self.dml.expressions() + } + + fn fmt_for_explain(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.dml.fmt(f) + } + + fn with_exprs_and_inputs( + &self, + _exprs: Vec, + inputs: Vec, + ) -> Result { + Ok(Self { + dml: inputs[0].clone(), + table: self.table.clone(), + }) + } +} + #[cfg(test)] mod test { use super::*; diff --git a/ballista/executor/Cargo.toml b/ballista/executor/Cargo.toml index facde01ed..d5bd5efb5 100644 --- a/ballista/executor/Cargo.toml +++ b/ballista/executor/Cargo.toml @@ -19,7 +19,7 @@ name = "ballista-executor" description = "Ballista Distributed Compute - Executor" license = "Apache-2.0" -version = "44.0.0" +version = "45.0.0" homepage = "https://github.com/apache/arrow-ballista" repository = "https://github.com/apache/arrow-ballista" readme = "README.md" @@ -42,7 +42,7 @@ default = ["build-binary", "mimalloc"] arrow = { workspace = true } arrow-flight = { workspace = true } async-trait = { workspace = true } -ballista-core = { path = "../core", version = "44.0.0" } +ballista-core = { path = "../core", version = "45.0.0" } configure_me = { workspace = true, optional = true } dashmap = { workspace = true } datafusion = { workspace = true } diff --git a/ballista/scheduler/Cargo.toml b/ballista/scheduler/Cargo.toml index 585dab985..ebfcbde38 100644 --- a/ballista/scheduler/Cargo.toml +++ b/ballista/scheduler/Cargo.toml @@ -19,7 +19,7 @@ name = "ballista-scheduler" description = "Ballista Distributed Compute - Scheduler" license = "Apache-2.0" -version = "44.0.0" +version = "45.0.0" homepage = "https://github.com/apache/arrow-ballista" repository = "https://github.com/apache/arrow-ballista" readme = "README.md" @@ -46,7 +46,7 @@ rest-api = ["graphviz-rust"] arrow-flight = { workspace = true } async-trait = { workspace = true } axum = "0.7.7" -ballista-core = { path = "../core", version = "44.0.0" } +ballista-core = { path = "../core", version = "45.0.0" } base64 = { version = "0.22", optional = true } clap = { workspace = true, optional = true } configure_me = { workspace = true, optional = true } diff --git a/ballista/scheduler/src/planner.rs b/ballista/scheduler/src/planner.rs index a01267091..7d8a19bd0 100644 --- a/ballista/scheduler/src/planner.rs +++ b/ballista/scheduler/src/planner.rs @@ -556,7 +556,7 @@ order by let join = coalesce_batches.children()[0].clone(); let join = downcast_exec!(join, HashJoinExec); - assert!(join.contain_projection()); + assert!(join.contains_projection()); let join_input_1 = join.children()[0].clone(); // skip CoalesceBatches @@ -687,7 +687,7 @@ order by assert_eq!(Some(&Column::new("l_shipmode", 1)), partition_by); assert_eq!(InputOrderMode::Sorted, window.input_order_mode); let sort = downcast_exec!(window.children()[0], SortExec); - match &sort.expr().inner[..] { + match &sort.expr().iter().collect::>()[..] { [expr1, expr2] => { assert_eq!( SortOptions { diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs b/ballista/scheduler/src/scheduler_server/grpc.rs index 02c21a884..4a65aed3b 100644 --- a/ballista/scheduler/src/scheduler_server/grpc.rs +++ b/ballista/scheduler/src/scheduler_server/grpc.rs @@ -32,6 +32,9 @@ use ballista_core::serde::protobuf::{ UpdateTaskStatusParams, UpdateTaskStatusResult, }; use ballista_core::serde::scheduler::ExecutorMetadata; +use ballista_core::serde::BallistaDmlExtension; +use datafusion::datasource::DefaultTableSource; +use datafusion::logical_expr::{Extension, LogicalPlan}; use datafusion_proto::logical_plan::AsLogicalPlan; use datafusion_proto::physical_plan::AsExecutionPlan; use log::{debug, error, info, trace, warn}; @@ -409,6 +412,34 @@ impl SchedulerGrpc self.state.codec.logical_extension_codec(), ) }) { + Ok(LogicalPlan::Extension(Extension { node })) + if node + .as_any() + .downcast_ref::() + .is_some() => + { + let plan = node + .as_any() + .downcast_ref::() + .unwrap(); + + let table_provider = &plan + .table + .source + .as_any() + .downcast_ref::() + .expect("Default Table Source is expected") + .table_provider; + + let _ = session_ctx + .deregister_table(plan.table.table_name.clone()); + let _ = session_ctx.register_table( + plan.table.table_name.clone(), + table_provider.clone(), + ); + + plan.dml.clone() + } Ok(plan) => plan, Err(e) => { let msg = diff --git a/ballista/scheduler/src/state/task_manager.rs b/ballista/scheduler/src/state/task_manager.rs index 8bad64f62..53a352bd1 100644 --- a/ballista/scheduler/src/state/task_manager.rs +++ b/ballista/scheduler/src/state/task_manager.rs @@ -26,6 +26,7 @@ use ballista_core::error::BallistaError; use ballista_core::error::Result; use ballista_core::extension::SessionConfigHelperExt; use datafusion::prelude::SessionConfig; +use rand::distr::Alphanumeric; use crate::cluster::JobState; use ballista_core::serde::protobuf::{ @@ -39,8 +40,7 @@ use datafusion::physical_plan::ExecutionPlan; use datafusion_proto::logical_plan::AsLogicalPlan; use datafusion_proto::physical_plan::AsExecutionPlan; use log::{debug, error, info, trace, warn}; -use rand::distributions::Alphanumeric; -use rand::{thread_rng, Rng}; +use rand::{rng, Rng}; use std::collections::{HashMap, HashSet}; use std::ops::Deref; use std::sync::Arc; @@ -644,7 +644,7 @@ impl TaskManager /// Generate a new random Job ID pub fn generate_job_id(&self) -> String { - let mut rng = thread_rng(); + let mut rng = rng(); std::iter::repeat(()) .map(|()| rng.sample(Alphanumeric)) .map(char::from) diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml index 4d851cdcb..0c0655486 100644 --- a/benchmarks/Cargo.toml +++ b/benchmarks/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "ballista-benchmarks" description = "Ballista Benchmarks" -version = "44.0.0" +version = "45.0.0" edition = "2021" authors = ["Apache DataFusion "] homepage = "https://github.com/apache/arrow-ballista" @@ -32,7 +32,7 @@ default = ["mimalloc"] snmalloc = ["snmalloc-rs"] [dependencies] -ballista = { path = "../ballista/client", version = "44.0.0" } +ballista = { path = "../ballista/client", version = "45.0.0" } datafusion = { workspace = true } datafusion-proto = { workspace = true } env_logger = { workspace = true } @@ -51,4 +51,4 @@ tokio = { version = "^1.0", features = [ ] } [dev-dependencies] -ballista-core = { path = "../ballista/core", version = "44.0.0" } +ballista-core = { path = "../ballista/core", version = "45.0.0" } diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index 72cc848df..1e9f4a37f 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -501,7 +501,7 @@ async fn loadtest_ballista(opt: BallistaLoadtestOpt) -> Result<()> { let query_id = query_list_clone .get( (0..query_list_clone.len()) - .choose(&mut rand::thread_rng()) + .choose(&mut rand::rng()) .unwrap(), ) .unwrap(); diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 1df640fcb..266d01018 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -18,7 +18,7 @@ [package] name = "ballista-examples" description = "Ballista usage examples" -version = "44.0.0" +version = "45.0.0" homepage = "https://github.com/apache/arrow-ballista" repository = "https://github.com/apache/arrow-ballista" authors = ["Apache DataFusion "] @@ -33,10 +33,10 @@ path = "examples/standalone-sql.rs" required-features = ["ballista/standalone"] [dependencies] -ballista = { path = "../ballista/client", version = "44.0.0" } -ballista-core = { path = "../ballista/core", version = "44.0.0" } -ballista-executor = { path = "../ballista/executor", version = "44.0.0", default-features = false } -ballista-scheduler = { path = "../ballista/scheduler", version = "44.0.0", default-features = false } +ballista = { path = "../ballista/client", version = "45.0.0" } +ballista-core = { path = "../ballista/core", version = "45.0.0" } +ballista-executor = { path = "../ballista/executor", version = "45.0.0", default-features = false } +ballista-scheduler = { path = "../ballista/scheduler", version = "45.0.0", default-features = false } datafusion = { workspace = true } env_logger = { workspace = true } log = { workspace = true }