diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index 08374bfcc5b21..43db654e83f9f 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -28,7 +28,6 @@ use std::{ use datafusion::datasource::{MemTable, TableProvider}; use datafusion::error::{DataFusionError, Result}; -use datafusion::logical_plan::LogicalPlan; use datafusion::parquet::basic::Compression; use datafusion::parquet::file::properties::WriterProperties; use datafusion::physical_plan::display::DisplayableExecutionPlan; @@ -196,10 +195,12 @@ async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result = Vec::with_capacity(1); for i in 0..opt.iterations { let start = Instant::now(); - let plans = create_logical_plans(&ctx, opt.query)?; - for plan in plans { - result = execute_query(&ctx, &plan, opt.debug).await?; + + let sql = &get_query_sql(opt.query)?; + for query in sql { + result = execute_query(&ctx, query, opt.debug).await?; } + let elapsed = start.elapsed().as_secs_f64() * 1000.0; millis.push(elapsed as f64); let row_count = result.iter().map(|b| b.num_rows()).sum(); @@ -253,7 +254,7 @@ fn get_query_sql(query: usize) -> Result> { .map(|s| s.trim()) .filter(|s| !s.is_empty()) .map(|s| s.to_string()) - .collect()) + .collect()); } Err(e) => errors.push(format!("{}: {}", filename, e)), }; @@ -269,23 +270,18 @@ fn get_query_sql(query: usize) -> Result> { } } -/// Create a logical plan for each query in the specified query file -fn create_logical_plans(ctx: &SessionContext, query: usize) -> Result> { - let sql = get_query_sql(query)?; - sql.iter() - .map(|sql| ctx.create_logical_plan(sql.as_str())) - .collect::>>() -} - async fn execute_query( ctx: &SessionContext, - plan: &LogicalPlan, + sql: &str, debug: bool, ) -> Result> { + let plan = ctx.sql(sql).await?; + let plan = plan.to_logical_plan()?; + if debug { println!("=== Logical plan ===\n{:?}\n", plan); } - let plan = ctx.optimize(plan)?; + let plan = ctx.optimize(&plan)?; if debug { println!("=== Optimized logical plan ===\n{:?}\n", plan); } @@ -357,7 +353,7 @@ async fn convert_tbl(opt: ConvertOpt) -> Result<()> { return Err(DataFusionError::NotImplemented(format!( "Invalid compression format: {}", other - ))) + ))); } }; let props = WriterProperties::builder() @@ -369,7 +365,7 @@ async fn convert_tbl(opt: ConvertOpt) -> Result<()> { return Err(DataFusionError::NotImplemented(format!( "Invalid output format: {}", other - ))) + ))); } } println!("Conversion completed in {} ms", start.elapsed().as_millis()); @@ -1022,9 +1018,9 @@ mod tests { ctx.register_table(table, Arc::new(provider))?; } - let plans = create_logical_plans(&ctx, n)?; - for plan in plans { - execute_query(&ctx, &plan, false).await?; + let sql = &get_query_sql(n)?; + for query in sql { + execute_query(&ctx, query, false).await?; } Ok(())