diff --git a/datafusion/__init__.py b/datafusion/__init__.py index b2e1028f2..ddab950be 100644 --- a/datafusion/__init__.py +++ b/datafusion/__init__.py @@ -30,6 +30,8 @@ Config, DataFrame, SessionContext, + SessionConfig, + RuntimeConfig, Expression, ScalarUDF, ) @@ -40,6 +42,8 @@ "Config", "DataFrame", "SessionContext", + "SessionConfig", + "RuntimeConfig", "Expression", "AggregateUDF", "ScalarUDF", diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py index 55849edf0..48d41c114 100644 --- a/datafusion/tests/test_context.py +++ b/datafusion/tests/test_context.py @@ -20,7 +20,13 @@ import pyarrow as pa import pyarrow.dataset as ds -from datafusion import column, literal, SessionContext +from datafusion import ( + column, + literal, + SessionContext, + SessionConfig, + RuntimeConfig, +) import pytest @@ -29,19 +35,24 @@ def test_create_context_no_args(): def test_create_context_with_all_valid_args(): - ctx = SessionContext( - target_partitions=1, - default_catalog="foo", - default_schema="bar", - create_default_catalog_and_schema=True, - information_schema=True, - repartition_joins=False, - repartition_aggregations=False, - repartition_windows=False, - parquet_pruning=False, - config_options=None, + + runtime = ( + RuntimeConfig().with_disk_manager_os().with_fair_spill_pool(10000000) + ) + config = ( + SessionConfig() + .with_create_default_catalog_and_schema(True) + .with_default_catalog_and_schema("foo", "bar") + .with_target_partitions(1) + .with_information_schema(True) + .with_repartition_joins(False) + .with_repartition_aggregations(False) + .with_repartition_windows(False) + .with_parquet_pruning(False) ) + ctx = SessionContext(config, runtime) + # verify that at least some of the arguments worked ctx.catalog("foo").database("bar") with pytest.raises(KeyError): diff --git a/src/context.rs b/src/context.rs index c50d0392a..8dcd1d6ff 100644 --- a/src/context.rs +++ b/src/context.rs @@ -40,11 +40,161 @@ use 
datafusion::arrow::record_batch::RecordBatch; use datafusion::datasource::datasource::TableProvider; use datafusion::datasource::MemTable; use datafusion::execution::context::{SessionConfig, SessionContext}; +use datafusion::execution::disk_manager::DiskManagerConfig; +use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool}; +use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::prelude::{ AvroReadOptions, CsvReadOptions, DataFrame, NdJsonReadOptions, ParquetReadOptions, }; use datafusion_common::ScalarValue; +#[pyclass(name = "SessionConfig", module = "datafusion", subclass, unsendable)] +#[derive(Clone, Default)] +pub(crate) struct PySessionConfig { + pub(crate) config: SessionConfig, +} + +impl From<SessionConfig> for PySessionConfig { + fn from(config: SessionConfig) -> Self { + Self { config } + } +} + +#[pymethods] +impl PySessionConfig { + #[pyo3(signature = (config_options=None))] + #[new] + fn new(config_options: Option<HashMap<String, String>>) -> Self { + let mut config = SessionConfig::new(); + if let Some(hash_map) = config_options { + for (k, v) in &hash_map { + config = config.set(k, ScalarValue::Utf8(Some(v.clone()))); + } + } + + Self { config } + } + + fn with_create_default_catalog_and_schema(&self, enabled: bool) -> Self { + Self::from( + self.config + .clone() + .with_create_default_catalog_and_schema(enabled), + ) + } + + fn with_default_catalog_and_schema(&self, catalog: &str, schema: &str) -> Self { + Self::from( + self.config + .clone() + .with_default_catalog_and_schema(catalog, schema), + ) + } + + fn with_information_schema(&self, enabled: bool) -> Self { + Self::from(self.config.clone().with_information_schema(enabled)) + } + + fn with_batch_size(&self, batch_size: usize) -> Self { + Self::from(self.config.clone().with_batch_size(batch_size)) + } + + fn with_target_partitions(&self, target_partitions: usize) -> Self { + Self::from( + self.config + .clone() + .with_target_partitions(target_partitions), + )
+ } + + fn with_repartition_aggregations(&self, enabled: bool) -> Self { + Self::from(self.config.clone().with_repartition_aggregations(enabled)) + } + + fn with_repartition_joins(&self, enabled: bool) -> Self { + Self::from(self.config.clone().with_repartition_joins(enabled)) + } + + fn with_repartition_windows(&self, enabled: bool) -> Self { + Self::from(self.config.clone().with_repartition_windows(enabled)) + } + + fn with_repartition_sorts(&self, enabled: bool) -> Self { + Self::from(self.config.clone().with_repartition_sorts(enabled)) + } + + fn with_repartition_file_scans(&self, enabled: bool) -> Self { + Self::from(self.config.clone().with_repartition_file_scans(enabled)) + } + + fn with_repartition_file_min_size(&self, size: usize) -> Self { + Self::from(self.config.clone().with_repartition_file_min_size(size)) + } + + fn with_parquet_pruning(&self, enabled: bool) -> Self { + Self::from(self.config.clone().with_parquet_pruning(enabled)) + } +} + +#[pyclass(name = "RuntimeConfig", module = "datafusion", subclass, unsendable)] +#[derive(Clone)] +pub(crate) struct PyRuntimeConfig { + pub(crate) config: RuntimeConfig, +} + +#[pymethods] +impl PyRuntimeConfig { + #[new] + fn new() -> Self { + Self { + config: RuntimeConfig::default(), + } + } + + fn with_disk_manager_disabled(&self) -> Self { + let config = self.config.clone(); + let config = config.with_disk_manager(DiskManagerConfig::Disabled); + Self { config } + } + + fn with_disk_manager_os(&self) -> Self { + let config = self.config.clone(); + let config = config.with_disk_manager(DiskManagerConfig::NewOs); + Self { config } + } + + fn with_disk_manager_specified(&self, paths: Vec<String>) -> Self { + let config = self.config.clone(); + let paths = paths.iter().map(|s| s.into()).collect(); + let config = config.with_disk_manager(DiskManagerConfig::NewSpecified(paths)); + Self { config } + } + + fn with_unbounded_memory_pool(&self) -> Self { + let config = self.config.clone(); + let config =
config.with_memory_pool(Arc::new(UnboundedMemoryPool::default())); + Self { config } + } + + fn with_fair_spill_pool(&self, size: usize) -> Self { + let config = self.config.clone(); + let config = config.with_memory_pool(Arc::new(FairSpillPool::new(size))); + Self { config } + } + + fn with_greedy_memory_pool(&self, size: usize) -> Self { + let config = self.config.clone(); + let config = config.with_memory_pool(Arc::new(GreedyMemoryPool::new(size))); + Self { config } + } + + fn with_temp_file_path(&self, path: &str) -> Self { + let config = self.config.clone(); + let config = config.with_temp_file_path(path); + Self { config } + } +} + /// `PySessionContext` is able to plan and execute DataFusion plans. /// It has a powerful optimizer, a physical planner for local execution, and a /// multi-threaded execution engine to perform the execution. @@ -56,54 +206,22 @@ pub(crate) struct PySessionContext { #[pymethods] impl PySessionContext { - #[allow(clippy::too_many_arguments)] - #[pyo3(signature = (default_catalog="datafusion", - default_schema="public", - create_default_catalog_and_schema=true, - information_schema=false, - repartition_joins=true, - repartition_aggregations=true, - repartition_windows=true, - parquet_pruning=true, - target_partitions=None, - config_options=None))] #[new] - fn new( - default_catalog: &str, - default_schema: &str, - create_default_catalog_and_schema: bool, - information_schema: bool, - repartition_joins: bool, - repartition_aggregations: bool, - repartition_windows: bool, - parquet_pruning: bool, - target_partitions: Option<usize>, - config_options: Option<HashMap<String, String>>, - ) -> PyResult<Self> { - let mut cfg = SessionConfig::new() - .with_information_schema(information_schema) - .with_repartition_joins(repartition_joins) - .with_repartition_aggregations(repartition_aggregations) - .with_repartition_windows(repartition_windows) - .with_parquet_pruning(parquet_pruning); - - if create_default_catalog_and_schema
{ - cfg = cfg.with_default_catalog_and_schema(default_catalog, default_schema); - } - - if let Some(hash_map) = config_options { - for (k, v) in &hash_map { - cfg = cfg.set(k, ScalarValue::Utf8(Some(v.clone()))); - } - } - - let cfg_full = match target_partitions { - None => cfg, - Some(x) => cfg.with_target_partitions(x), + fn new(config: Option<PySessionConfig>, runtime: Option<PyRuntimeConfig>) -> PyResult<Self> { + let config = if let Some(c) = config { + c.config + } else { + SessionConfig::default() }; - + let runtime_config = if let Some(c) = runtime { + c.config + } else { + RuntimeConfig::default() + }; + let runtime = Arc::new(RuntimeEnv::new(runtime_config)?); Ok(PySessionContext { - ctx: SessionContext::with_config(cfg_full), + ctx: SessionContext::with_config_rt(config, runtime), }) } diff --git a/src/lib.rs b/src/lib.rs index be699d529..5391de57c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -60,6 +60,8 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::<context::PySessionConfig>()?; + m.add_class::<context::PyRuntimeConfig>()?; m.add_class::()?; m.add_class::()?; m.add_class::()?;