diff --git a/datafusion/src/dataframe.rs b/datafusion/src/dataframe.rs index c8c5dcc1c5e6b..dfe0ae6f0247f 100644 --- a/datafusion/src/dataframe.rs +++ b/datafusion/src/dataframe.rs @@ -17,15 +17,22 @@ //! DataFrame API for building and executing query plans. +use crate::arrow::datatypes::Schema; +use crate::arrow::datatypes::SchemaRef; use crate::arrow::record_batch::RecordBatch; +use crate::datasource::TableProvider; +use crate::datasource::TableType; use crate::error::Result; +use crate::execution::dataframe_impl::DataFrameImpl; use crate::logical_plan::{ DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, Partitioning, }; -use std::sync::Arc; - +use crate::physical_plan::ExecutionPlan; use crate::physical_plan::SendableRecordBatchStream; +use crate::scalar::ScalarValue; use async_trait::async_trait; +use std::any::Any; +use std::sync::Arc; /// DataFrame represents a logical set of rows with the same named columns. /// Similar to a [Pandas DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html) or @@ -53,7 +60,7 @@ use async_trait::async_trait; /// # } /// ``` #[async_trait] -pub trait DataFrame: Send + Sync { +pub trait DataFrame: TableProvider + Send + Sync { /// Filter the DataFrame by column. Returns a new DataFrame only containing the /// specified columns. /// @@ -328,7 +335,7 @@ pub trait DataFrame: Send + Sync { /// where each column has a name, data type, and nullability attribute. /// ``` - /// # use datafusion::prelude::*; + /// # use datafusion::prelude::{CsvReadOptions, ExecutionContext, DataFrame}; /// # use datafusion::error::Result; /// # #[tokio::main] /// # async fn main() -> Result<()> { @@ -406,3 +413,60 @@ pub trait DataFrame: Send + Sync { /// ``` fn except(&self, dataframe: Arc) -> Result>; } + +#[async_trait] +impl TableProvider for D +where + D: DataFrame + 'static, +{ + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + let schema: Schema = self.to_logical_plan().schema().as_ref().into(); + Arc::new(schema) + } + + fn table_type(&self) -> TableType { + TableType::View + } + + async fn scan( + &self, + projection: &Option>, + filters: &[Expr], + limit: Option, + ) -> Result> { + let plan = self.to_logical_plan(); + let expr = projection + .as_ref() + // construct projections + .map_or_else( + || Ok(Arc::new(DataFrameImpl::new(Default::default(), &plan)) as Arc<_>), + |projection| { + let schema = TableProvider::schema(self).project(projection)?; + let names = schema + .fields() + .iter() + .map(|field| field.name().as_str()) + .collect::>(); + self.select_columns(names.as_slice()) + }, + )? + // add predicates, otherwise use `true` as the predicate + .filter(filters.iter().cloned().fold( + Expr::Literal(ScalarValue::Boolean(Some(true))), + |acc, new| acc.and(new), + ))?; + // add a limit if given + DataFrameImpl::new( + Default::default(), + &limit + .map_or_else(|| Ok(expr.clone()), |n| expr.limit(n))? + .to_logical_plan(), + ) + .create_physical_plan() + .await + } +} diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs index 6ed8223f0c527..244343a6aa14c 100644 --- a/datafusion/src/execution/context.rs +++ b/datafusion/src/execution/context.rs @@ -1573,7 +1573,9 @@ mod tests { let tmp_dir = TempDir::new()?; let ctx = create_ctx(&tmp_dir, 1).await?; - let schema: Schema = ctx.table("test").unwrap().schema().clone().into(); + let schema: Schema = DataFrame::schema(&*ctx.table("test").unwrap()) + .clone() + .into(); assert!(!schema.field_with_name("c1")?.is_nullable()); let plan = LogicalPlanBuilder::scan_empty(None, &schema, None)? diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs index d3f62bbb46dbb..c1933adaa6416 100644 --- a/datafusion/src/execution/dataframe_impl.rs +++ b/datafusion/src/execution/dataframe_impl.rs @@ -17,11 +17,8 @@ //! Implementation of DataFrame API. -use std::any::Any; use std::sync::{Arc, Mutex}; -use crate::arrow::datatypes::Schema; -use crate::arrow::datatypes::SchemaRef; use crate::arrow::record_batch::RecordBatch; use crate::error::Result; use crate::execution::context::{ExecutionContext, ExecutionContextState}; @@ -29,15 +26,12 @@ use crate::logical_plan::{ col, DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, LogicalPlanBuilder, Partitioning, }; -use crate::scalar::ScalarValue; use crate::{ dataframe::*, physical_plan::{collect, collect_partitioned}, }; use crate::arrow::util::pretty; -use crate::datasource::TableProvider; -use crate::datasource::TableType; use crate::physical_plan::{ execute_stream, execute_stream_partitioned, ExecutionPlan, SendableRecordBatchStream, }; @@ -60,7 +54,7 @@ impl DataFrameImpl { } /// Create a physical plan - async fn create_physical_plan(&self) -> Result> { + pub(crate) async fn create_physical_plan(&self) -> Result> { let state = self.ctx_state.lock().unwrap().clone(); let ctx = ExecutionContext::from(Arc::new(Mutex::new(state))); let plan = ctx.optimize(&self.plan)?; @@ -68,59 +62,6 @@ impl DataFrameImpl { } } -#[async_trait] -impl TableProvider for DataFrameImpl { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - let schema: Schema = self.plan.schema().as_ref().into(); - Arc::new(schema) - } - - fn table_type(&self) -> TableType { - TableType::View - } - - async fn scan( - &self, - projection: &Option>, - filters: &[Expr], - limit: Option, - ) -> Result> { - let expr = projection - .as_ref() - // construct projections - .map_or_else( - || Ok(Arc::new(Self::new(self.ctx_state.clone(), &self.plan)) as Arc<_>), - |projection| { - let schema = TableProvider::schema(self).project(projection)?; - let names = schema - .fields() - .iter() - .map(|field| field.name().as_str()) - .collect::>(); - self.select_columns(names.as_slice()) - }, - )? - // add predicates, otherwise use `true` as the predicate - .filter(filters.iter().cloned().fold( - Expr::Literal(ScalarValue::Boolean(Some(true))), - |acc, new| acc.and(new), - ))?; - // add a limit if given - Self::new( - self.ctx_state.clone(), - &limit - .map_or_else(|| Ok(expr.clone()), |n| expr.limit(n))? - .to_logical_plan(), - ) - .create_physical_plan() - .await - } -} - #[async_trait] impl DataFrame for DataFrameImpl { /// Apply a projection based on a list of column names @@ -602,6 +543,17 @@ mod tests { ); Ok(()) } + + #[tokio::test] + async fn register_dataframe() -> Result<()> { + let df = test_table().await?.select_columns(&["c1", "c12"])?; + let mut ctx = ExecutionContext::new(); + + // register a dataframe as a table + ctx.register_table("test_table", df)?; + Ok(()) + } + /// Compare the formatted string representation of two plans for equality fn assert_same_plan(plan1: &LogicalPlan, plan2: &LogicalPlan) { assert_eq!(format!("{:?}", plan1), format!("{:?}", plan2));