diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index d2c6513eef957..ba006247cd708 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -18,13 +18,13 @@ //! [`ScalarValue`]: stores single values mod struct_builder; - use std::borrow::Borrow; use std::cmp::Ordering; use std::collections::{HashSet, VecDeque}; use std::convert::Infallible; use std::fmt; use std::hash::Hash; +use std::hash::Hasher; use std::iter::repeat; use std::str::FromStr; use std::sync::Arc; @@ -55,6 +55,7 @@ use arrow::{ use arrow_buffer::Buffer; use arrow_schema::{UnionFields, UnionMode}; +use half::f16; pub use struct_builder::ScalarStructBuilder; /// A dynamically typed, nullable single value. @@ -192,6 +193,8 @@ pub enum ScalarValue { Null, /// true or false value Boolean(Option), + /// 16bit float + Float16(Option), /// 32bit float Float32(Option), /// 64bit float @@ -285,6 +288,12 @@ pub enum ScalarValue { Dictionary(Box, Box), } +impl Hash for Fl { + fn hash(&self, state: &mut H) { + self.0.to_bits().hash(state); + } +} + // manual implementation of `PartialEq` impl PartialEq for ScalarValue { fn eq(&self, other: &Self) -> bool { @@ -307,7 +316,12 @@ impl PartialEq for ScalarValue { (Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(), _ => v1.eq(v2), }, + (Float16(v1), Float16(v2)) => match (v1, v2) { + (Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(), + _ => v1.eq(v2), + }, (Float32(_), _) => false, + (Float16(_), _) => false, (Float64(v1), Float64(v2)) => match (v1, v2) { (Some(f1), Some(f2)) => f1.to_bits() == f2.to_bits(), _ => v1.eq(v2), @@ -425,7 +439,12 @@ impl PartialOrd for ScalarValue { (Some(f1), Some(f2)) => Some(f1.total_cmp(f2)), _ => v1.partial_cmp(v2), }, + (Float16(v1), Float16(v2)) => match (v1, v2) { + (Some(f1), Some(f2)) => Some(f1.total_cmp(f2)), + _ => v1.partial_cmp(v2), + }, (Float32(_), _) => None, + (Float16(_), _) => None, (Float64(v1), Float64(v2)) => match (v1, v2) { (Some(f1), Some(f2)) => Some(f1.total_cmp(f2)), _ => v1.partial_cmp(v2), @@ -637,6 +656,7 @@ impl std::hash::Hash for ScalarValue { s.hash(state) } Boolean(v) => v.hash(state), + Float16(v) => v.map(Fl).hash(state), Float32(v) => v.map(Fl).hash(state), Float64(v) => v.map(Fl).hash(state), Int8(v) => v.hash(state), @@ -1082,6 +1102,7 @@ impl ScalarValue { ScalarValue::TimestampNanosecond(_, tz_opt) => { DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()) } + ScalarValue::Float16(_) => DataType::Float16, ScalarValue::Float32(_) => DataType::Float32, ScalarValue::Float64(_) => DataType::Float64, ScalarValue::Utf8(_) => DataType::Utf8, @@ -1276,6 +1297,7 @@ impl ScalarValue { match self { ScalarValue::Boolean(v) => v.is_none(), ScalarValue::Null => true, + ScalarValue::Float16(v) => v.is_none(), ScalarValue::Float32(v) => v.is_none(), ScalarValue::Float64(v) => v.is_none(), ScalarValue::Decimal128(v, _, _) => v.is_none(), @@ -1522,6 +1544,7 @@ impl ScalarValue { } DataType::Null => ScalarValue::iter_to_null_array(scalars)?, DataType::Boolean => build_array_primitive!(BooleanArray, Boolean), + DataType::Float16 => build_array_primitive!(Float16Array, Float16), DataType::Float32 => build_array_primitive!(Float32Array, Float32), DataType::Float64 => build_array_primitive!(Float64Array, Float64), DataType::Int8 => build_array_primitive!(Int8Array, Int8), @@ -1682,8 +1705,7 @@ impl ScalarValue { // not supported if the TimeUnit is not valid (Time32 can // only be used with Second and Millisecond, Time64 only // with Microsecond and Nanosecond) - DataType::Float16 - | DataType::Time32(TimeUnit::Microsecond) + DataType::Time32(TimeUnit::Microsecond) | DataType::Time32(TimeUnit::Nanosecond) | DataType::Time64(TimeUnit::Second) | DataType::Time64(TimeUnit::Millisecond) @@ -1700,7 +1722,6 @@ impl ScalarValue { ); } }; - Ok(array) } @@ -1921,6 +1942,9 @@ impl ScalarValue { ScalarValue::Float32(e) => { build_array_from_option!(Float32, Float32Array, e, size) } + ScalarValue::Float16(e) => { + build_array_from_option!(Float16, Float16Array, e, size) + } ScalarValue::Int8(e) => build_array_from_option!(Int8, Int8Array, e, size), ScalarValue::Int16(e) => build_array_from_option!(Int16, Int16Array, e, size), ScalarValue::Int32(e) => build_array_from_option!(Int32, Int32Array, e, size), @@ -2595,6 +2619,9 @@ impl ScalarValue { ScalarValue::Boolean(val) => { eq_array_primitive!(array, index, BooleanArray, val)? } + ScalarValue::Float16(val) => { + eq_array_primitive!(array, index, Float16Array, val)? + } ScalarValue::Float32(val) => { eq_array_primitive!(array, index, Float32Array, val)? } @@ -2738,6 +2765,7 @@ impl ScalarValue { + match self { ScalarValue::Null | ScalarValue::Boolean(_) + | ScalarValue::Float16(_) | ScalarValue::Float32(_) | ScalarValue::Float64(_) | ScalarValue::Decimal128(_, _, _) @@ -3022,6 +3050,7 @@ impl TryFrom<&DataType> for ScalarValue { fn try_from(data_type: &DataType) -> Result { Ok(match data_type { DataType::Boolean => ScalarValue::Boolean(None), + DataType::Float16 => ScalarValue::Float16(None), DataType::Float64 => ScalarValue::Float64(None), DataType::Float32 => ScalarValue::Float32(None), DataType::Int8 => ScalarValue::Int8(None), @@ -3147,6 +3176,7 @@ impl fmt::Display for ScalarValue { write!(f, "{v:?},{p:?},{s:?}")?; } ScalarValue::Boolean(e) => format_option!(f, e)?, + ScalarValue::Float16(e) => format_option!(f, e)?, ScalarValue::Float32(e) => format_option!(f, e)?, ScalarValue::Float64(e) => format_option!(f, e)?, ScalarValue::Int8(e) => format_option!(f, e)?, @@ -3260,6 +3290,7 @@ impl fmt::Debug for ScalarValue { ScalarValue::Decimal128(_, _, _) => write!(f, "Decimal128({self})"), ScalarValue::Decimal256(_, _, _) => write!(f, "Decimal256({self})"), ScalarValue::Boolean(_) => write!(f, "Boolean({self})"), + ScalarValue::Float16(_) => write!(f, "Float16({self})"), ScalarValue::Float32(_) => write!(f, "Float32({self})"), ScalarValue::Float64(_) => write!(f, "Float64({self})"), ScalarValue::Int8(_) => write!(f, "Int8({self})"), diff --git a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs index e7e6360c2500d..6c738cfe03a95 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/statistics.rs @@ -25,11 +25,11 @@ use arrow_schema::{Field, FieldRef, Schema}; use datafusion_common::{ internal_datafusion_err, internal_err, plan_err, Result, ScalarValue, }; +use half::f16; use parquet::file::metadata::ParquetMetaData; use parquet::file::statistics::Statistics as ParquetStatistics; use parquet::schema::types::SchemaDescriptor; use std::sync::Arc; - // Convert the bytes array to i128. // The endian of the input bytes array must be big-endian. pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 { @@ -39,6 +39,14 @@ pub(crate) fn from_bytes_to_i128(b: &[u8]) -> i128 { i128::from_be_bytes(sign_extend_be(b)) } +// Convert the bytes array to f16 +pub(crate) fn from_bytes_to_f16(b: &[u8]) -> Option { + match b { + [low, high] => Some(f16::from_be_bytes([*high, *low])), + _ => None, + } +} + // Copy from arrow-rs // https://github.com/apache/arrow-rs/blob/733b7e7fd1e8c43a404c3ce40ecf741d493c21b4/parquet/src/arrow/buffer/bit_util.rs#L55 // Convert the byte slice to fixed length byte array with the length of 16 @@ -196,6 +204,9 @@ macro_rules! get_statistic { value, )) } + Some(DataType::Float16) => { + Some(ScalarValue::Float16(from_bytes_to_f16(s.$bytes_func()))) + } _ => None, } } @@ -344,7 +355,6 @@ impl<'a> StatisticsConverter<'a> { column_name ); }; - Ok(Self { column_name, statistics_type, diff --git a/datafusion/core/tests/parquet/arrow_statistics.rs b/datafusion/core/tests/parquet/arrow_statistics.rs index aa5fc7c34c481..c2bf75c8f0896 100644 --- a/datafusion/core/tests/parquet/arrow_statistics.rs +++ b/datafusion/core/tests/parquet/arrow_statistics.rs @@ -21,6 +21,7 @@ use std::fs::File; use std::sync::Arc; +use crate::parquet::{struct_array, Scenario}; use arrow::compute::kernels::cast_utils::Parser; use arrow::datatypes::{ Date32Type, Date64Type, TimestampMicrosecondType, TimestampMillisecondType, @@ -28,21 +29,21 @@ use arrow::datatypes::{ }; use arrow_array::{ make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, - Decimal128Array, FixedSizeBinaryArray, Float32Array, Float64Array, Int16Array, - Int32Array, Int64Array, Int8Array, LargeStringArray, RecordBatch, StringArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + Decimal128Array, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, RecordBatch, + StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, }; use arrow_schema::{DataType, Field, Schema}; use datafusion::datasource::physical_plan::parquet::{ RequestedStatistics, StatisticsConverter, }; +use half::f16; use parquet::arrow::arrow_reader::{ArrowReaderBuilder, ParquetRecordBatchReaderBuilder}; use parquet::arrow::ArrowWriter; use parquet::file::properties::{EnabledStatistics, WriterProperties}; -use crate::parquet::{struct_array, Scenario}; - use super::make_test_file_rg; // TEST HELPERS @@ -1203,6 +1204,36 @@ async fn test_float64() { .run(); } +#[tokio::test] +async fn test_float16() { + // This creates a parquet file of 1 column "f" + // file has 4 record batches, each has 5 rows. They will be saved into 4 row groups + let reader = TestReader { + scenario: Scenario::Float16, + row_per_group: 5, + }; + + Test { + reader: reader.build().await, + expected_min: Arc::new(Float16Array::from( + vec![-5.0, -4.0, -0.0, 5.0] + .into_iter() + .map(f16::from_f32) + .collect::>(), + )), + expected_max: Arc::new(Float16Array::from( + vec![-1.0, 0.0, 4.0, 9.0] + .into_iter() + .map(f16::from_f32) + .collect::>(), + )), + expected_null_counts: UInt64Array::from(vec![0, 0, 0, 0]), + expected_row_counts: UInt64Array::from(vec![5, 5, 5, 5]), + column_name: "f", + } + .run(); +} + #[tokio::test] async fn test_decimal() { // This creates a parquet file of 1 column "decimal_col" with decimal data type and precicion 9, scale 2 diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index bfb6e8e555c93..e951644f2cbfd 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -19,20 +19,17 @@ use arrow::array::Decimal128Array; use arrow::{ array::{ - Array, ArrayRef, BinaryArray, Date32Array, Date64Array, FixedSizeBinaryArray, - Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, StringArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + make_array, Array, ArrayRef, BinaryArray, BooleanArray, Date32Array, Date64Array, + DictionaryArray, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array, + Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, StringArray, + StructArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, }, - datatypes::{DataType, Field, Schema}, + datatypes::{DataType, Field, Int32Type, Int8Type, Schema}, record_batch::RecordBatch, util::pretty::pretty_format_batches, }; -use arrow_array::types::{Int32Type, Int8Type}; -use arrow_array::{ - make_array, BooleanArray, DictionaryArray, Float32Array, LargeStringArray, - StructArray, -}; use chrono::{Datelike, Duration, TimeDelta}; use datafusion::{ datasource::{physical_plan::ParquetExec, provider_as_source, TableProvider}, @@ -40,11 +37,11 @@ use datafusion::{ prelude::{ParquetReadOptions, SessionConfig, SessionContext}, }; use datafusion_expr::{Expr, LogicalPlan, LogicalPlanBuilder}; +use half::f16; use parquet::arrow::ArrowWriter; use parquet::file::properties::WriterProperties; use std::sync::Arc; use tempfile::NamedTempFile; - mod arrow_statistics; mod custom_reader; mod file_statistics; @@ -79,6 +76,7 @@ enum Scenario { /// 7 Rows, for each i8, i16, i32, i64, u8, u16, u32, u64, f32, f64 /// -MIN, -100, -1, 0, 1, 100, MAX NumericLimits, + Float16, Float64, Decimal, DecimalBloomFilterInt32, @@ -542,6 +540,12 @@ fn make_f64_batch(v: Vec) -> RecordBatch { RecordBatch::try_new(schema, vec![array.clone()]).unwrap() } +fn make_f16_batch(v: Vec) -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("f", DataType::Float16, true)])); + let array = Arc::new(Float16Array::from(v)) as ArrayRef; + RecordBatch::try_new(schema, vec![array.clone()]).unwrap() +} + /// Return record batch with decimal vector /// /// Columns are named @@ -897,6 +901,34 @@ fn create_data_batch(scenario: Scenario) -> Vec { Scenario::NumericLimits => { vec![make_numeric_limit_batch()] } + Scenario::Float16 => { + vec![ + make_f16_batch( + vec![-5.0, -4.0, -3.0, -2.0, -1.0] + .into_iter() + .map(f16::from_f32) + .collect(), + ), + make_f16_batch( + vec![-4.0, -3.0, -2.0, -1.0, 0.0] + .into_iter() + .map(f16::from_f32) + .collect(), + ), + make_f16_batch( + vec![0.0, 1.0, 2.0, 3.0, 4.0] + .into_iter() + .map(f16::from_f32) + .collect(), + ), + make_f16_batch( + vec![5.0, 6.0, 7.0, 8.0, 9.0] + .into_iter() + .map(f16::from_f32) + .collect(), + ), + ] + } Scenario::Float64 => { vec![ make_f64_batch(vec![-5.0, -4.0, -3.0, -2.0, -1.0]), @@ -1087,7 +1119,6 @@ async fn make_test_file_rg(scenario: Scenario, row_per_group: usize) -> NamedTem .build(); let batches = create_data_batch(scenario); - let schema = batches[0].schema(); let mut writer = ArrowWriter::try_new(&mut output_file, schema, Some(props)).unwrap(); diff --git a/datafusion/proto-common/src/to_proto/mod.rs b/datafusion/proto-common/src/to_proto/mod.rs index f160bc40af396..a92deaa88b1ca 100644 --- a/datafusion/proto-common/src/to_proto/mod.rs +++ b/datafusion/proto-common/src/to_proto/mod.rs @@ -294,6 +294,11 @@ impl TryFrom<&ScalarValue> for protobuf::ScalarValue { ScalarValue::Boolean(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::BoolValue(*s)) } + ScalarValue::Float16(val) => { + create_proto_scalar(val.as_ref(), &data_type, |s| { + Value::Float32Value((*s).into()) + }) + } ScalarValue::Float32(val) => { create_proto_scalar(val.as_ref(), &data_type, |s| Value::Float32Value(*s)) } diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 1ba6638e73d7b..3efbe2ace680d 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -643,6 +643,10 @@ impl Unparser<'_> { Ok(ast::Expr::Value(ast::Value::Boolean(b.to_owned()))) } ScalarValue::Boolean(None) => Ok(ast::Expr::Value(ast::Value::Null)), + ScalarValue::Float16(Some(f)) => { + Ok(ast::Expr::Value(ast::Value::Number(f.to_string(), false))) + } + ScalarValue::Float16(None) => Ok(ast::Expr::Value(ast::Value::Null)), ScalarValue::Float32(Some(f)) => { Ok(ast::Expr::Value(ast::Value::Number(f.to_string(), false))) }