Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 153 additions & 74 deletions rust/arrow/src/compute/kernels/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@ pub fn can_cast_types(from_type: &DataType, to_type: &DataType) -> bool {
(Utf8, Date64) => true,
(Utf8, Timestamp(TimeUnit::Nanosecond, None)) => true,
(Utf8, _) => DataType::is_numeric(to_type),
(LargeUtf8, Date32) => true,
(LargeUtf8, Date64) => true,
(LargeUtf8, Timestamp(TimeUnit::Nanosecond, None)) => true,
(LargeUtf8, _) => DataType::is_numeric(to_type),
(_, Utf8) | (_, LargeUtf8) => {
DataType::is_numeric(from_type) || from_type == &Binary
}
Expand Down Expand Up @@ -366,66 +370,20 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
},
(Utf8, _) => match to_type {
LargeUtf8 => cast_str_container::<i32, i64>(&**array),
UInt8 => cast_string_to_numeric::<UInt8Type>(array),
UInt16 => cast_string_to_numeric::<UInt16Type>(array),
UInt32 => cast_string_to_numeric::<UInt32Type>(array),
UInt64 => cast_string_to_numeric::<UInt64Type>(array),
Int8 => cast_string_to_numeric::<Int8Type>(array),
Int16 => cast_string_to_numeric::<Int16Type>(array),
Int32 => cast_string_to_numeric::<Int32Type>(array),
Int64 => cast_string_to_numeric::<Int64Type>(array),
Float32 => cast_string_to_numeric::<Float32Type>(array),
Float64 => cast_string_to_numeric::<Float64Type>(array),
Date32 => {
use chrono::Datelike;
let string_array = array.as_any().downcast_ref::<StringArray>().unwrap();
let mut builder = PrimitiveBuilder::<Date32Type>::new(string_array.len());
for i in 0..string_array.len() {
if string_array.is_null(i) {
builder.append_null()?;
} else {
match string_array.value(i).parse::<chrono::NaiveDate>() {
Ok(date) => builder.append_value(
date.num_days_from_ce() - EPOCH_DAYS_FROM_CE,
)?,
Err(_) => builder.append_null()?, // not a valid date
};
}
}
Ok(Arc::new(builder.finish()) as ArrayRef)
}
Date64 => {
let string_array = array.as_any().downcast_ref::<StringArray>().unwrap();
let mut builder = PrimitiveBuilder::<Date64Type>::new(string_array.len());
for i in 0..string_array.len() {
if string_array.is_null(i) {
builder.append_null()?;
} else {
match string_array.value(i).parse::<chrono::NaiveDateTime>() {
Ok(date_time) => {
builder.append_value(date_time.timestamp_millis())?
}
Err(_) => builder.append_null()?, // not a valid date
};
}
}
Ok(Arc::new(builder.finish()) as ArrayRef)
}
UInt8 => cast_string_to_numeric::<UInt8Type, i32>(array),
UInt16 => cast_string_to_numeric::<UInt16Type, i32>(array),
UInt32 => cast_string_to_numeric::<UInt32Type, i32>(array),
UInt64 => cast_string_to_numeric::<UInt64Type, i32>(array),
Int8 => cast_string_to_numeric::<Int8Type, i32>(array),
Int16 => cast_string_to_numeric::<Int16Type, i32>(array),
Int32 => cast_string_to_numeric::<Int32Type, i32>(array),
Int64 => cast_string_to_numeric::<Int64Type, i32>(array),
Float32 => cast_string_to_numeric::<Float32Type, i32>(array),
Float64 => cast_string_to_numeric::<Float64Type, i32>(array),
Date32 => cast_string_to_date32::<i32>(&**array),
Date64 => cast_string_to_date64::<i32>(&**array),
Timestamp(TimeUnit::Nanosecond, None) => {
let string_array = array.as_any().downcast_ref::<StringArray>().unwrap();
let mut builder =
PrimitiveBuilder::<TimestampNanosecondType>::new(string_array.len());
for i in 0..string_array.len() {
if string_array.is_null(i) {
builder.append_null()?;
} else {
match string_to_timestamp_nanos(string_array.value(i)) {
Ok(nanos) => builder.append_value(nanos)?,
Err(_) => builder.append_null()?, // not a valid date
};
}
}
Ok(Arc::new(builder.finish()) as ArrayRef)
cast_string_to_timestamp_ns::<i32>(&**array)
}
_ => Err(ArrowError::ComputeError(format!(
"Casting from {:?} to {:?} not supported",
Expand Down Expand Up @@ -487,6 +445,27 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
from_type, to_type,
))),
},
(LargeUtf8, _) => match to_type {
UInt8 => cast_string_to_numeric::<UInt8Type, i64>(array),
UInt16 => cast_string_to_numeric::<UInt16Type, i64>(array),
UInt32 => cast_string_to_numeric::<UInt32Type, i64>(array),
UInt64 => cast_string_to_numeric::<UInt64Type, i64>(array),
Int8 => cast_string_to_numeric::<Int8Type, i64>(array),
Int16 => cast_string_to_numeric::<Int16Type, i64>(array),
Int32 => cast_string_to_numeric::<Int32Type, i64>(array),
Int64 => cast_string_to_numeric::<Int64Type, i64>(array),
Float32 => cast_string_to_numeric::<Float32Type, i64>(array),
Float64 => cast_string_to_numeric::<Float64Type, i64>(array),
Date32 => cast_string_to_date32::<i64>(&**array),
Date64 => cast_string_to_date64::<i64>(&**array),
Timestamp(TimeUnit::Nanosecond, None) => {
cast_string_to_timestamp_ns::<i64>(&**array)
}
_ => Err(ArrowError::ComputeError(format!(
"Casting from {:?} to {:?} not supported",
from_type, to_type,
))),
},

// start numeric casts
(UInt8, UInt16) => cast_numeric_arrays::<UInt8Type, UInt16Type>(array),
Expand Down Expand Up @@ -949,17 +928,23 @@ where

/// Cast numeric types to Utf8
#[allow(clippy::unnecessary_wraps)]
fn cast_string_to_numeric<T>(from: &ArrayRef) -> Result<ArrayRef>
fn cast_string_to_numeric<T, Offset: StringOffsetSizeTrait>(
from: &ArrayRef,
) -> Result<ArrayRef>
where
T: ArrowNumericType,
<T as ArrowPrimitiveType>::Native: lexical_core::FromLexical,
{
Ok(Arc::new(string_to_numeric_cast::<T>(
from.as_any().downcast_ref::<StringArray>().unwrap(),
Ok(Arc::new(string_to_numeric_cast::<T, Offset>(
from.as_any()
.downcast_ref::<GenericStringArray<Offset>>()
.unwrap(),
)))
}

fn string_to_numeric_cast<T>(from: &StringArray) -> PrimitiveArray<T>
fn string_to_numeric_cast<T, Offset: StringOffsetSizeTrait>(
from: &GenericStringArray<Offset>,
) -> PrimitiveArray<T>
where
T: ArrowNumericType,
<T as ArrowPrimitiveType>::Native: lexical_core::FromLexical,
Expand All @@ -978,6 +963,93 @@ where
unsafe { PrimitiveArray::<T>::from_trusted_len_iter(iter) }
}

/// Casts generic string arrays to Date32Array
#[allow(clippy::unnecessary_wraps)]
fn cast_string_to_date32<Offset: StringOffsetSizeTrait>(
array: &dyn Array,
) -> Result<ArrayRef> {
use chrono::Datelike;
let string_array = array
.as_any()
.downcast_ref::<GenericStringArray<Offset>>()
.unwrap();

let iter = (0..string_array.len()).map(|i| {
if string_array.is_null(i) {
None
} else {
string_array
.value(i)
.parse::<chrono::NaiveDate>()
.map(|date| date.num_days_from_ce() - EPOCH_DAYS_FROM_CE)
.ok()
}
});

// Benefit:
// 20% performance improvement
// Soundness:
// The iterator is trustedLen because it comes from an `StringArray`.
let array = unsafe { Date32Array::from_trusted_len_iter(iter) };
Ok(Arc::new(array) as ArrayRef)
}

/// Casts generic string arrays to Date64Array
#[allow(clippy::unnecessary_wraps)]
fn cast_string_to_date64<Offset: StringOffsetSizeTrait>(
array: &dyn Array,
) -> Result<ArrayRef> {
let string_array = array
.as_any()
.downcast_ref::<GenericStringArray<Offset>>()
.unwrap();

let iter = (0..string_array.len()).map(|i| {
if string_array.is_null(i) {
None
} else {
string_array
.value(i)
.parse::<chrono::NaiveDateTime>()
.map(|datetime| datetime.timestamp_millis())
.ok()
}
});

// Benefit:
// 20% performance improvement
// Soundness:
// The iterator is trustedLen because it comes from an `StringArray`.
let array = unsafe { Date64Array::from_trusted_len_iter(iter) };
Ok(Arc::new(array) as ArrayRef)
}

/// Casts generic string arrays to TimeStampNanosecondArray
#[allow(clippy::unnecessary_wraps)]
fn cast_string_to_timestamp_ns<Offset: StringOffsetSizeTrait>(
array: &dyn Array,
) -> Result<ArrayRef> {
let string_array = array
.as_any()
.downcast_ref::<GenericStringArray<Offset>>()
.unwrap();

let iter = (0..string_array.len()).map(|i| {
if string_array.is_null(i) {
None
} else {
string_to_timestamp_nanos(string_array.value(i)).ok()
}
});

// Benefit:
// 20% performance improvement
// Soundness:
// The iterator is trustedLen because it comes from an `StringArray`.
let array = unsafe { TimestampNanosecondArray::from_trusted_len_iter(iter) };
Ok(Arc::new(array) as ArrayRef)
}

/// Cast numeric types to Boolean
///
/// Any zero value returns `false` while non-zero returns `true`
Expand Down Expand Up @@ -1719,20 +1791,27 @@ mod tests {

#[test]
fn test_cast_string_to_timestamp() {
let a = StringArray::from(vec![
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I recommend adding coverage here to convert from LargeUtf8 string arrays as well - I don't think they are covered by these tests

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nevi-me I just fixed this test. I already modified the test with that intention but I apparently tested stringarray twice instead of stringarray and largestringarray

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @ritchie46 !

let a1 = Arc::new(StringArray::from(vec![
Some("2020-09-08T12:00:00+00:00"),
Some("Not a valid date"),
None,
]);
let array = Arc::new(a) as ArrayRef;
let b = cast(&array, &DataType::Timestamp(TimeUnit::Nanosecond, None)).unwrap();
let c = b
.as_any()
.downcast_ref::<TimestampNanosecondArray>()
.unwrap();
assert_eq!(1599566400000000000, c.value(0));
assert!(c.is_null(1));
assert!(c.is_null(2));
])) as ArrayRef;
let a2 = Arc::new(LargeStringArray::from(vec![
Some("2020-09-08T12:00:00+00:00"),
Some("Not a valid date"),
None,
])) as ArrayRef;
for array in &[a1, a2] {
let b =
cast(array, &DataType::Timestamp(TimeUnit::Nanosecond, None)).unwrap();
let c = b
.as_any()
.downcast_ref::<TimestampNanosecondArray>()
.unwrap();
assert_eq!(1599566400000000000, c.value(0));
assert!(c.is_null(1));
assert!(c.is_null(2));
}
}

#[test]
Expand Down