Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 5 additions & 19 deletions datafusion/core/tests/sqllogictests/src/engines/datafusion/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,11 @@ mod normalize;
pub struct DataFusion {
ctx: SessionContext,
file_name: String,
is_pg_compatibility_test: bool,
}

impl DataFusion {
pub fn new(
ctx: SessionContext,
file_name: String,
postgres_compatible: bool,
) -> Self {
Self {
ctx,
file_name,
is_pg_compatibility_test: postgres_compatible,
}
pub fn new(ctx: SessionContext, file_name: String) -> Self {
Self { ctx, file_name }
}
}

Expand All @@ -57,7 +48,7 @@ impl sqllogictest::AsyncDB for DataFusion {

async fn run(&mut self, sql: &str) -> Result<DBOutput> {
println!("[{}] Running query: \"{}\"", self.file_name, sql);
let result = run_query(&self.ctx, sql, self.is_pg_compatibility_test).await?;
let result = run_query(&self.ctx, sql).await?;
Ok(result)
}

Expand All @@ -76,11 +67,7 @@ impl sqllogictest::AsyncDB for DataFusion {
}
}

async fn run_query(
ctx: &SessionContext,
sql: impl Into<String>,
is_pg_compatibility_test: bool,
) -> Result<DBOutput> {
async fn run_query(ctx: &SessionContext, sql: impl Into<String>) -> Result<DBOutput> {
let sql = sql.into();
// Check if the sql is `insert`
if let Ok(mut statements) = DFParser::parse_sql(&sql) {
Expand All @@ -94,7 +81,6 @@ async fn run_query(
}
let df = ctx.sql(sql.as_str()).await?;
let results: Vec<RecordBatch> = df.collect().await?;
let formatted_batches =
normalize::convert_batches(results, is_pg_compatibility_test)?;
let formatted_batches = normalize::convert_batches(results)?;
Ok(formatted_batches)
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ use super::error::{DFSqlLogicTestError, Result};
///
/// Assumes empty record batches are a successful statement completion
///
pub fn convert_batches(
batches: Vec<RecordBatch>,
is_pg_compatibility_test: bool,
) -> Result<DBOutput> {
pub fn convert_batches(batches: Vec<RecordBatch>) -> Result<DBOutput> {
if batches.is_empty() {
// DataFusion doesn't report number of rows complete
return Ok(DBOutput::StatementComplete(0));
Expand All @@ -53,23 +50,20 @@ pub fn convert_batches(
),
)));
}
rows.append(&mut convert_batch(batch, is_pg_compatibility_test)?);
rows.append(&mut convert_batch(batch)?);
}

Ok(DBOutput::Rows { types, rows })
}

/// Convert a single batch to a `Vec<Vec<String>>` for comparison
fn convert_batch(
batch: RecordBatch,
is_pg_compatibility_test: bool,
) -> Result<Vec<Vec<String>>> {
fn convert_batch(batch: RecordBatch) -> Result<Vec<Vec<String>>> {
(0..batch.num_rows())
.map(|row| {
batch
.columns()
.iter()
.map(|col| cell_to_string(col, row, is_pg_compatibility_test))
.map(|col| cell_to_string(col, row))
.collect::<Result<Vec<String>>>()
})
.collect()
Expand All @@ -93,18 +87,29 @@ macro_rules! get_row_value {
///
/// Floating numbers are rounded to have a consistent representation with the Postgres runner.
///
pub fn cell_to_string(
col: &ArrayRef,
row: usize,
is_pg_compatibility_test: bool,
) -> Result<String> {
pub fn cell_to_string(col: &ArrayRef, row: usize) -> Result<String> {
if !col.is_valid(row) {
// represent any null value with the string "NULL"
Ok(NULL_STR.to_string())
} else if is_pg_compatibility_test {
postgres_compatible_cell_to_string(col, row)
} else {
match col.data_type() {
DataType::Boolean => {
Ok(bool_to_str(get_row_value!(array::BooleanArray, col, row)))
}
DataType::Float16 => {
Ok(f16_to_str(get_row_value!(array::Float16Array, col, row)))
}
DataType::Float32 => {
Ok(f32_to_str(get_row_value!(array::Float32Array, col, row)))
}
DataType::Float64 => {
Ok(f64_to_str(get_row_value!(array::Float64Array, col, row)))
}
DataType::Decimal128(_, scale) => {
let value = get_row_value!(array::Decimal128Array, col, row);
let decimal_scale = u32::try_from((*scale).max(0)).unwrap();
Ok(i128_to_str(value, decimal_scale))
}
DataType::LargeUtf8 => Ok(varchar_to_str(get_row_value!(
array::LargeStringArray,
col,
Expand All @@ -118,36 +123,3 @@ pub fn cell_to_string(
.map_err(DFSqlLogicTestError::Arrow)
}
}

/// Convert values to text representation that are the same as in Postgres client implementation.
fn postgres_compatible_cell_to_string(col: &ArrayRef, row: usize) -> Result<String> {
match col.data_type() {
DataType::Boolean => {
Ok(bool_to_str(get_row_value!(array::BooleanArray, col, row)))
}
DataType::Float16 => {
Ok(f16_to_str(get_row_value!(array::Float16Array, col, row)))
}
DataType::Float32 => {
Ok(f32_to_str(get_row_value!(array::Float32Array, col, row)))
}
DataType::Float64 => {
Ok(f64_to_str(get_row_value!(array::Float64Array, col, row)))
}
DataType::Decimal128(_, scale) => {
let value = get_row_value!(array::Decimal128Array, col, row);
let decimal_scale = u32::try_from((*scale).max(0)).unwrap();
Ok(i128_to_str(value, decimal_scale))
}
DataType::LargeUtf8 => Ok(varchar_to_str(get_row_value!(
array::LargeStringArray,
col,
row
))),
DataType::Utf8 => {
Ok(varchar_to_str(get_row_value!(array::StringArray, col, row)))
}
_ => arrow::util::display::array_value_to_string(col, row),
}
.map_err(DFSqlLogicTestError::Arrow)
}
81 changes: 33 additions & 48 deletions datafusion/core/tests/sqllogictests/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use std::error::Error;
use std::path::{Path, PathBuf};

use log::{debug, info};
use log::info;
use testcontainers::clients::Cli as Docker;

use datafusion::prelude::SessionContext;
Expand Down Expand Up @@ -56,36 +56,23 @@ pub async fn main() -> Result<(), Box<dyn Error>> {

for path in files {
let file_name = path.file_name().unwrap().to_str().unwrap().to_string();
let is_pg_compatibility_test = file_name.starts_with(PG_COMPAT_FILE_PREFIX);

if options.complete_mode {
run_complete_file(&path, file_name, is_pg_compatibility_test).await?;
run_complete_file(&path, file_name).await?;
} else if options.postgres_runner {
if is_pg_compatibility_test {
run_test_file_with_postgres(&path, file_name).await?;
} else {
debug!("Skipping test file {:?}", path);
}
run_test_file_with_postgres(&path, file_name).await?;
} else {
run_test_file(&path, file_name, is_pg_compatibility_test).await?;
run_test_file(&path, file_name).await?;
}
}

Ok(())
}

async fn run_test_file(
path: &PathBuf,
file_name: String,
is_pg_compatibility_test: bool,
) -> Result<(), Box<dyn Error>> {
async fn run_test_file(path: &PathBuf, file_name: String) -> Result<(), Box<dyn Error>> {
println!("Running with DataFusion runner: {}", path.display());
let ctx = context_for_test_file(&file_name, is_pg_compatibility_test).await;
let mut runner = sqllogictest::Runner::new(DataFusion::new(
ctx,
file_name,
is_pg_compatibility_test,
));
let ctx = context_for_test_file(&file_name).await;
let mut runner = sqllogictest::Runner::new(DataFusion::new(ctx, file_name));
runner.run_file_async(path).await?;
Ok(())
}
Expand Down Expand Up @@ -117,18 +104,13 @@ async fn run_test_file_with_postgres(
async fn run_complete_file(
path: &PathBuf,
file_name: String,
is_pg_compatibility_test: bool,
) -> Result<(), Box<dyn Error>> {
use sqllogictest::{default_validator, update_test_file};

info!("Using complete mode to complete: {}", path.display());

let ctx = context_for_test_file(&file_name, is_pg_compatibility_test).await;
let mut runner = sqllogictest::Runner::new(DataFusion::new(
ctx,
file_name,
is_pg_compatibility_test,
));
let ctx = context_for_test_file(&file_name).await;
let mut runner = sqllogictest::Runner::new(DataFusion::new(ctx, file_name));

info!("Using complete mode to complete {}", path.display());
let col_separator = " ";
Expand All @@ -145,31 +127,28 @@ fn read_test_files(options: &Options) -> Vec<PathBuf> {
.unwrap()
.map(|path| path.unwrap().path())
.filter(|path| options.check_test_file(path.as_path()))
.filter(|path| options.check_pg_compat_file(path.as_path()))
.collect()
}

/// Create a SessionContext, configured for the specific test
async fn context_for_test_file(
file_name: &str,
is_pg_compatibility_test: bool,
) -> SessionContext {
if is_pg_compatibility_test {
info!("Registering pg compatibility tables");
let ctx = SessionContext::new();
setup::register_aggregate_csv_by_sql(&ctx).await;
ctx
} else {
match file_name {
"aggregate.slt" | "select.slt" => {
info!("Registering aggregate tables");
let ctx = SessionContext::new();
setup::register_aggregate_tables(&ctx).await;
ctx
}
_ => {
info!("Using default SessionContext");
SessionContext::new()
}
async fn context_for_test_file(file_name: &str) -> SessionContext {
match file_name {
"aggregate.slt" | "select.slt" => {
info!("Registering aggregate tables");
let ctx = SessionContext::new();
setup::register_aggregate_tables(&ctx).await;
ctx
}
_ if file_name.starts_with(PG_COMPAT_FILE_PREFIX) => {
info!("Registering pg compatibility tables");
let ctx = SessionContext::new();
setup::register_aggregate_csv_by_sql(&ctx).await;
ctx
}
_ => {
info!("Using default SessionContext");
SessionContext::new()
}
}
}
Expand Down Expand Up @@ -233,4 +212,10 @@ impl Options {
let path_str = path.to_string_lossy();
self.filters.iter().any(|filter| path_str.contains(filter))
}

/// Postgres runner executes only tests in files with specific names
fn check_pg_compat_file(&self, path: &Path) -> bool {
let file_name = path.file_name().unwrap().to_str().unwrap().to_string();
!self.postgres_runner || file_name.starts_with(PG_COMPAT_FILE_PREFIX)
}
}
Loading