39 changes: 35 additions & 4 deletions datafusion/common/src/scalar/mod.rs
@@ -18,13 +18,13 @@
//! [`ScalarValue`]: stores single values

mod struct_builder;

use std::borrow::Borrow;
use std::cmp::Ordering;
use std::collections::{HashSet, VecDeque};
use std::convert::Infallible;
use std::fmt;
use std::hash::Hash;
use std::hash::Hasher;
use std::iter::repeat;
use std::str::FromStr;
use std::sync::Arc;
@@ -55,6 +55,7 @@ use arrow::{
use arrow_buffer::Buffer;
use arrow_schema::{UnionFields, UnionMode};

use half::f16;
pub use struct_builder::ScalarStructBuilder;

/// A dynamically typed, nullable single value.
@@ -192,6 +193,8 @@ pub enum ScalarValue {
Null,
/// true or false value
Boolean(Option<bool>),
/// 16bit float
Float16(Option<f16>),
/// 32bit float
Float32(Option<f32>),
/// 64bit float
@@ -285,6 +288,12 @@ pub enum ScalarValue {
Dictionary(Box<DataType>, Box<ScalarValue>),
}

impl Hash for Fl<f16> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.0.to_bits().hash(state);
}
}

// manual implementation of `PartialEq`
impl PartialEq for ScalarValue {
fn eq(&self, other: &Self) -> bool {
@@ -307,7 +316,12 @@ impl PartialEq for ScalarValue {
(Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
_ => v1.eq(v2),
},
(Float16(v1), Float16(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
_ => v1.eq(v2),
},
(Float32(_), _) => false,
(Float16(_), _) => false,
(Float64(v1), Float64(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(),
_ => v1.eq(v2),
@@ -425,7 +439,12 @@ impl PartialOrd for ScalarValue {
(Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
_ => v1.partial_cmp(v2),
},
(Float16(v1), Float16(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
_ => v1.partial_cmp(v2),
},
(Float32(_), _) => None,
(Float16(_), _) => None,
(Float64(v1), Float64(v2)) => match (v1, v2) {
(Some(f1), Some(f2)) => Some(f1.total_cmp(f2)),
_ => v1.partial_cmp(v2),
@@ -637,6 +656,7 @@ impl std::hash::Hash for ScalarValue {
s.hash(state)
}
Boolean(v) => v.hash(state),
Float16(v) => v.map(Fl).hash(state),
Float32(v) => v.map(Fl).hash(state),
Float64(v) => v.map(Fl).hash(state),
Int8(v) => v.hash(state),
@@ -1082,6 +1102,7 @@ impl ScalarValue {
ScalarValue::TimestampNanosecond(_, tz_opt) => {
DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone())
}
ScalarValue::Float16(_) => DataType::Float16,
ScalarValue::Float32(_) => DataType::Float32,
ScalarValue::Float64(_) => DataType::Float64,
ScalarValue::Utf8(_) => DataType::Utf8,
@@ -1276,6 +1297,7 @@ impl ScalarValue {
match self {
ScalarValue::Boolean(v) => v.is_none(),
ScalarValue::Null => true,
ScalarValue::Float16(v) => v.is_none(),
ScalarValue::Float32(v) => v.is_none(),
ScalarValue::Float64(v) => v.is_none(),
ScalarValue::Decimal128(v, _, _) => v.is_none(),
@@ -1522,6 +1544,7 @@ impl ScalarValue {
}
DataType::Null => ScalarValue::iter_to_null_array(scalars)?,
DataType::Boolean => build_array_primitive!(BooleanArray, Boolean),
DataType::Float16 => build_array_primitive!(Float16Array, Float16),
DataType::Float32 => build_array_primitive!(Float32Array, Float32),
DataType::Float64 => build_array_primitive!(Float64Array, Float64),
DataType::Int8 => build_array_primitive!(Int8Array, Int8),
@@ -1682,8 +1705,7 @@ impl ScalarValue {
// not supported if the TimeUnit is not valid (Time32 can
// only be used with Second and Millisecond, Time64 only
// with Microsecond and Nanosecond)
DataType::Float16
| DataType::Time32(TimeUnit::Microsecond)
DataType::Time32(TimeUnit::Microsecond)
| DataType::Time32(TimeUnit::Nanosecond)
| DataType::Time64(TimeUnit::Second)
| DataType::Time64(TimeUnit::Millisecond)
@@ -1700,7 +1722,6 @@
);
}
};

Ok(array)
}

@@ -1921,6 +1942,9 @@ impl ScalarValue {
ScalarValue::Float32(e) => {
build_array_from_option!(Float32, Float32Array, e, size)
}
ScalarValue::Float16(e) => {
build_array_from_option!(Float16, Float16Array, e, size)
}
ScalarValue::Int8(e) => build_array_from_option!(Int8, Int8Array, e, size),
ScalarValue::Int16(e) => build_array_from_option!(Int16, Int16Array, e, size),
ScalarValue::Int32(e) => build_array_from_option!(Int32, Int32Array, e, size),
@@ -2595,6 +2619,9 @@ impl ScalarValue {
ScalarValue::Boolean(val) => {
eq_array_primitive!(array, index, BooleanArray, val)?
}
ScalarValue::Float16(val) => {
eq_array_primitive!(array, index, Float16Array, val)?
}
ScalarValue::Float32(val) => {
eq_array_primitive!(array, index, Float32Array, val)?
}
@@ -2738,6 +2765,7 @@ impl ScalarValue {
+ match self {
ScalarValue::Null
| ScalarValue::Boolean(_)
| ScalarValue::Float16(_)
| ScalarValue::Float32(_)
| ScalarValue::Float64(_)
| ScalarValue::Decimal128(_, _, _)
@@ -3022,6 +3050,7 @@ impl TryFrom<&DataType> for ScalarValue {
fn try_from(data_type: &DataType) -> Result<Self> {
Ok(match data_type {
DataType::Boolean => ScalarValue::Boolean(None),
DataType::Float16 => ScalarValue::Float16(None),
DataType::Float64 => ScalarValue::Float64(None),
DataType::Float32 => ScalarValue::Float32(None),
DataType::Int8 => ScalarValue::Int8(None),
@@ -3147,6 +3176,7 @@ impl fmt::Display for ScalarValue {
write!(f, "{v:?},{p:?},{s:?}")?;
}
ScalarValue::Boolean(e) => format_option!(f, e)?,
ScalarValue::Float16(e) => format_option!(f, e)?,
ScalarValue::Float32(e) => format_option!(f, e)?,
ScalarValue::Float64(e) => format_option!(f, e)?,
ScalarValue::Int8(e) => format_option!(f, e)?,
@@ -3260,6 +3290,7 @@ impl fmt::Debug for ScalarValue {
ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({self})"),
ScalarValue::Decimal256(_, _, _) => write!(f, "Decimal256({self})"),
ScalarValue::Boolean(_) => write!(f, "Boolean({self})"),
ScalarValue::Float16(_) => write!(f, "Float16({self})"),
ScalarValue::Float32(_) => write!(f, "Float32({self})"),
ScalarValue::Float64(_) => write!(f, "Float64({self})"),
ScalarValue::Int8(_) => write!(f, "Int8({self})"),
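A note on the `Float16` hunks above: `f16` (like `f32`/`f64`) does not implement `Hash`, and under IEEE comparison `NaN != NaN`, so equality, ordering, and hashing all go through `to_bits()` and `total_cmp()`. A minimal standalone sketch of the idea — the non-generic `Fl` below is a stand-in for the crate's wrapper, not the diff's code:

use std::cmp::Ordering;
use std::hash::{Hash, Hasher};

use half::f16;

// Hash a float through its bit pattern: this gives NaN (and every other
// value) a stable identity that plain float comparison cannot.
struct Fl(f16);

impl Hash for Fl {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.0.to_bits().hash(state);
    }
}

fn main() {
    let nan = f16::NAN;
    assert_ne!(nan, nan); // IEEE comparison: NaN != NaN
    assert_eq!(nan.to_bits(), nan.to_bits()); // the bit pattern is stable
    // total_cmp is a total order, so ordering stays consistent even with NaN:
    assert_eq!(f16::from_f32(1.0).total_cmp(&nan), Ordering::Less);
}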
14 changes: 12 additions & 2 deletions datafusion/core/src/datasource/physical_plan/parquet/statistics.rs
@@ -25,11 +25,11 @@ use arrow_schema::{Field, FieldRef, Schema};
use datafusion_common::{
internal_datafusion_err, internal_err, plan_err, Result, ScalarValue,
};
use half::f16;
use parquet::file::metadata::ParquetMetaData;
use parquet::file::statistics::Statistics as ParquetStatistics;
use parquet::schema::types::SchemaDescriptor;
use std::sync::Arc;

// Convert the bytes array to i128.
// The input byte array must be big-endian.
pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 {
@@ -39,6 +39,14 @@ pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 {
i128::from_be_bytes(sign_extend_be(b))
}

// Convert a 2-byte array to f16.
// Parquet stores Float16 values little-endian, so the bytes are reversed
// before the big-endian conversion; any other length yields None.
pub(crate) fn from_bytes_to_f16(b: &[u8]) -> Option<f16> {
match b {
[low, high] => Some(f16::from_be_bytes([*high, *low])),
_ => None,
}
}

// Copy from arrow-rs
// https://github.com/apache/arrow-rs/blob/733b7e7fd1e8c43a404c3ce40ecf741d493c21b4/parquet/src/arrow/buffer/bit_util.rs#L55
// Convert the byte slice to fixed length byte array with the length of 16
@@ -196,6 +204,9 @@ macro_rules! get_statistic {
value,
))
}
Some(DataType::Float16) => {
Some(ScalarValue::Float16(from_bytes_to_f16(s.$bytes_func())))
}
_ => None,
}
}
@@ -344,7 +355,6 @@ impl<'a> StatisticsConverter<'a> {
column_name
);
};

Ok(Self {
column_name,
statistics_type,
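The byte order in `from_bytes_to_f16` is easy to miss: Parquet encodes Float16 statistics as a 2-byte little-endian FixedLenByteArray, so the helper swaps the bytes before the big-endian conversion. A self-contained sketch of the same logic (the sample bytes are illustrative):

use half::f16;

fn from_bytes_to_f16(b: &[u8]) -> Option<f16> {
    match b {
        // Reverse little-endian input so from_be_bytes sees big-endian.
        [low, high] => Some(f16::from_be_bytes([*high, *low])),
        _ => None, // a statistics buffer of any other length is rejected
    }
}

fn main() {
    // 1.0_f16 is 0x3C00; little-endian on the wire is [0x00, 0x3C].
    assert_eq!(from_bytes_to_f16(&[0x00, 0x3C]), Some(f16::from_f32(1.0)));
    assert_eq!(from_bytes_to_f16(&[0x00]), None); // wrong length
}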
43 changes: 37 additions & 6 deletions datafusion/core/tests/parquet/arrow_statistics.rs
@@ -21,28 +21,29 @@
use std::fs::File;
use std::sync::Arc;

use crate::parquet::{struct_array, Scenario};
use arrow::compute::kernels::cast_utils::Parser;
use arrow::datatypes::{
Date32Type, Date64Type, TimestampMicrosecondType, TimestampMillisecondType,
TimestampNanosecondType, TimestampSecondType,
};
use arrow_array::{
make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array,
Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array, Int16Array,
Int32Array, Int64Array, Int8Array, LargeStringArray, RecordBatch, StringArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
Decimal128Array, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array,
Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, RecordBatch,
StringArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array,
UInt64Array, UInt8Array,
};
use arrow_schema::{DataType, Field, Schema};
use datafusion::datasource::physical_plan::parquet::{
RequestedStatistics, StatisticsConverter,
};
use half::f16;
use parquet::arrow::arrow_reader::{ArrowReaderBuilder, ParquetRecordBatchReaderBuilder};
use parquet::arrow::ArrowWriter;
use parquet::file::properties::{EnabledStatistics, WriterProperties};

use crate::parquet::{struct_array, Scenario};

use super::make_test_file_rg;

// TEST HELPERS
@@ -1203,6 +1204,36 @@ async fn test_float64() {
.run();
}

#[tokio::test]
async fn test_float16() {
// This creates a parquet file with a single column "f".
// The file has 4 record batches of 5 rows each, written as 4 row groups.
let reader = TestReader {
scenario: Scenario::Float16,
row_per_group: 5,
};

Test {
reader: reader.build().await,
expected_min: Arc::new(Float16Array::from(
vec![-5.0, -4.0, -0.0, 5.0]
.into_iter()
.map(f16::from_f32)
.collect::<Vec<_>>(),
)),
expected_max: Arc::new(Float16Array::from(
vec![-1.0, 0.0, 4.0, 9.0]
.into_iter()
.map(f16::from_f32)
.collect::<Vec<_>>(),
)),
expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]),
expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]),
column_name: "f",
}
.run();
}

#[tokio::test]
async fn test_decimal() {
// This creates a parquet file of 1 column "decimal_col" with decimal data type, precision 9, and scale 2
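For context on what the converter in this test reads, here is a hedged sketch of inspecting the raw per-row-group statistics with the parquet crate directly. The file path is hypothetical, and `has_min_max_set`/`min_bytes`/`max_bytes` are assumed to be the same accessors the `get_statistic!` macro relies on:

use std::fs::File;

use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let file = File::open("/tmp/f16_demo.parquet")?; // hypothetical file
    let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
    for rg in builder.metadata().row_groups() {
        if let Some(stats) = rg.column(0).statistics() {
            if stats.has_min_max_set() {
                // For a Float16 column these are 2-byte little-endian
                // buffers that decode via the byte-swap helper above.
                println!("min: {:?} max: {:?}", stats.min_bytes(), stats.max_bytes());
            }
        }
    }
    Ok(())
}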
55 changes: 43 additions & 12 deletions datafusion/core/tests/parquet/mod.rs
@@ -19,32 +19,29 @@
use arrow::array::Decimal128Array;
use arrow::{
array::{
Array, ArrayRef, BinaryArray, Date32Array, Date64Array, FixedSizeBinaryArray,
Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array,
DictionaryArray, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array,
Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, StringArray,
StructArray, TimestampMicrosecondArray, TimestampMillisecondArray,
TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array,
UInt64Array, UInt8Array,
},
datatypes::{DataType, Field, Schema},
datatypes::{DataType, Field, Int32Type, Int8Type, Schema},
record_batch::RecordBatch,
util::pretty::pretty_format_batches,
};
use arrow_array::types::{Int32Type, Int8Type};
use arrow_array::{
make_array, BooleanArray, DictionaryArray, Float32Array, LargeStringArray,
StructArray,
};
use chrono::{Datelike, Duration, TimeDelta};
use datafusion::{
datasource::{physical_plan::ParquetExec, provider_as_source, TableProvider},
physical_plan::{accept, metrics::MetricsSet, ExecutionPlan, ExecutionPlanVisitor},
prelude::{ParquetReadOptions, SessionConfig, SessionContext},
};
use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder};
use half::f16;
use parquet::arrow::ArrowWriter;
use parquet::file::properties::WriterProperties;
use std::sync::Arc;
use tempfile::NamedTempFile;

mod arrow_statistics;
mod custom_reader;
mod file_statistics;
@@ -79,6 +76,7 @@ enum Scenario {
/// 7 Rows, for each i8, i16, i32, i64, u8, u16, u32, u64, f32, f64
/// -MIN, -100, -1, 0, 1, 100, MAX
NumericLimits,
Float16,
Float64,
Decimal,
DecimalBloomFilterInt32,
@@ -542,6 +540,12 @@ fn make_f64_batch(v: Vec<f64>) -> RecordBatch {
RecordBatch::try_new(schema, vec![array.clone()]).unwrap()
}

fn make_f16_batch(v: Vec<f16>) -> RecordBatch {
let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Float16, true)]));
let array = Arc::new(Float16Array::from(v)) as ArrayRef;
RecordBatch::try_new(schema, vec![array.clone()]).unwrap()
}

/// Return record batch with decimal vector
///
/// Columns are named
@@ -897,6 +901,34 @@ fn create_data_batch(scenario: Scenario) -> Vec<RecordBatch> {
Scenario::NumericLimits => {
vec![make_numeric_limit_batch()]
}
Scenario::Float16 => {
vec![
make_f16_batch(
vec![-5.0, -4.0, -3.0, -2.0, -1.0]
.into_iter()
.map(f16::from_f32)
.collect(),
),
make_f16_batch(
vec![-4.0, -3.0, -2.0, -1.0, 0.0]
.into_iter()
.map(f16::from_f32)
.collect(),
),
make_f16_batch(
vec![0.0, 1.0, 2.0, 3.0, 4.0]
.into_iter()
.map(f16::from_f32)
.collect(),
),
make_f16_batch(
vec![5.0, 6.0, 7.0, 8.0, 9.0]
.into_iter()
.map(f16::from_f32)
.collect(),
),
]
}
Scenario::Float64 => {
vec![
make_f64_batch(vec![-5.0, -4.0, -3.0, -2.0, -1.0]),
@@ -1087,7 +1119,6 @@ async fn make_test_file_rg(scenario: Scenario, row_per_group: usize) -> NamedTempFile
.build();

let batches = create_data_batch(scenario);

let schema = batches[0].schema();

let mut writer = ArrowWriter::try_new(&mut output_file, schema, Some(props)).unwrap();
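To see `Scenario::Float16` end to end, a standalone sketch of writing one such batch the way `make_test_file_rg` does: a `max_row_group_size` equal to the batch size puts each batch in its own row group, and chunk-level statistics are what the tests above assert on. The output path is hypothetical:

use std::fs::File;
use std::sync::Arc;

use arrow_array::{ArrayRef, Float16Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use half::f16;
use parquet::arrow::ArrowWriter;
use parquet::file::properties::{EnabledStatistics, WriterProperties};

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Float16, true)]));
    let batch = RecordBatch::try_new(
        schema.clone(),
        vec![Arc::new(Float16Array::from(
            vec![-5.0f32, -4.0, -3.0, -2.0, -1.0]
                .into_iter()
                .map(f16::from_f32)
                .collect::<Vec<_>>(),
        )) as ArrayRef],
    )?;

    let props = WriterProperties::builder()
        .set_max_row_group_size(5) // 5 rows per row group, like row_per_group
        .set_statistics_enabled(EnabledStatistics::Chunk) // per-chunk min/max
        .build();

    let file = File::create("/tmp/f16_demo.parquet")?; // hypothetical path
    let mut writer = ArrowWriter::try_new(file, schema, Some(props))?;
    writer.write(&batch)?;
    writer.close()?;
    Ok(())
}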