diff --git a/rust/arrow/src/array/builder.rs b/rust/arrow/src/array/builder.rs
index 30cce75d00c..06cb5d59723 100644
--- a/rust/arrow/src/array/builder.rs
+++ b/rust/arrow/src/array/builder.rs
@@ -952,6 +952,8 @@ pub struct DecimalBuilder {
     builder: FixedSizeListBuilder<UInt8Builder>,
     precision: usize,
     scale: usize,
+    /// Largest magnitude `append_value` accepts: `10^precision - 1`.
+    max_value: i128,
 }
 
 impl ArrayBuilder
@@ -1221,10 +1223,16 @@ impl DecimalBuilder {
     pub fn new(capacity: usize, precision: usize, scale: usize) -> Self {
         let values_builder = UInt8Builder::new(capacity);
         let byte_width = 16;
+        // 10^precision does not fit in i128 for precision > 38; saturate instead of
+        // panicking (debug builds) or wrapping to a bogus limit (release builds).
+        let max_value = 10_i128
+            .checked_pow(precision as u32)
+            .map_or(std::i128::MAX, |limit| limit - 1);
         Self {
             builder: FixedSizeListBuilder::new(values_builder, byte_width),
             precision,
             scale,
+            max_value,
         }
     }
 
@@ -1233,6 +1241,14 @@ impl DecimalBuilder {
     /// Automatically calls the `append` method to delimit the slice appended in as a
     /// distinct array element.
+    ///
+    /// Returns an error if the magnitude of `value` exceeds `10^precision - 1`.
     pub fn append_value(&mut self, value: i128) -> Result<()> {
+        if value > self.max_value || value < -self.max_value {
+            return Err(ArrowError::InvalidArgumentError(format!(
+                "Value {} does not fit in decimal with precision {}",
+                value, self.precision
+            )));
+        }
         let value_as_bytes = Self::from_i128_to_fixed_size_bytes(
             value,
             self.builder.value_length() as usize,
@@ -2772,6 +2788,21 @@ mod tests {
         assert_eq!(16, decimal_array.value_length());
     }
 
+    #[test]
+    fn test_decimal_builder_fails_for_values_beyond_precision() {
+        let mut builder = DecimalBuilder::new(30, 5, 2);
+
+        builder.append_value(99999).unwrap();
+        assert!(builder.append_value(100000).is_err());
+        builder.append_value(-99999).unwrap();
+        assert!(builder.append_value(-100000).is_err());
+        let decimal_array: DecimalArray = builder.finish();
+
+        assert_eq!(&DataType::Decimal(5, 2), decimal_array.data_type());
+        assert_eq!(2, decimal_array.len());
+        assert_eq!(0, decimal_array.null_count());
+    }
+
     #[test]
     fn test_string_array_builder_finish() {
         let mut builder = StringBuilder::new(10);
diff --git a/rust/arrow/src/compute/kernels/cast.rs b/rust/arrow/src/compute/kernels/cast.rs
index 40b33fc649d..ca20f618e97 100644
--- a/rust/arrow/src/compute/kernels/cast.rs
+++ b/rust/arrow/src/compute/kernels/cast.rs
@@ -38,6 +38,8 @@
 use std::str;
 use std::sync::Arc;
 
+use num::ToPrimitive;
+
 use crate::compute::kernels::arithmetic::{divide, multiply};
 use crate::datatypes::*;
 use crate::error::{ArrowError, Result};
@@ -794,6 +796,30 @@ pub fn cast(array: &ArrayRef, to_type: &DataType) -> Result<ArrayRef> {
                 }
             }
         }
+        (Int16, Decimal(precision, scale)) => {
+            cast_int_to_decimal::<Int16Type>(array, *precision, *scale)
+        }
+        (Int32, Decimal(precision, scale)) => {
+            cast_int_to_decimal::<Int32Type>(array, *precision, *scale)
+        }
+        (Int64, Decimal(precision, scale)) => {
+            cast_int_to_decimal::<Int64Type>(array, *precision, *scale)
+        }
+        (UInt16, Decimal(precision, scale)) => {
+            cast_int_to_decimal::<UInt16Type>(array, *precision, *scale)
+        }
+        (UInt32, Decimal(precision, scale)) => {
+            cast_int_to_decimal::<UInt32Type>(array, *precision, *scale)
+        }
+        (UInt64, Decimal(precision, scale)) => {
+            cast_int_to_decimal::<UInt64Type>(array, *precision, *scale)
+        }
+        (Float32, Decimal(precision, scale)) => {
+            cast_float_to_decimal::<Float32Type>(array, *precision, *scale)
+        }
+        (Float64, Decimal(precision, scale)) => {
+            cast_float_to_decimal::<Float64Type>(array, *precision, *scale)
+        }
 
         // null to primitive/flat types
         (Null, Int32) => Ok(Arc::new(Int32Array::from(vec![None; array.len()]))),
@@ -1187,6 +1213,94 @@ where
     Ok(Arc::new(b.finish()))
 }
 
+/// Casts a primitive integer array to a `Decimal(precision, scale)` array.
+///
+/// Each non-null value is multiplied by `10^scale` to get the unscaled decimal
+/// representation; nulls are kept. Returns an error if that multiplication
+/// overflows `i128` or if the builder rejects the value for the given precision.
+fn cast_int_to_decimal<TO>(
+    array: &ArrayRef,
+    precision: usize,
+    scale: usize,
+) -> Result<ArrayRef>
+where
+    TO: ArrowNumericType,
+    TO::Native: num::ToPrimitive,
+{
+    let values = array.as_any().downcast_ref::<PrimitiveArray<TO>>().unwrap();
+
+    let mut builder = DecimalBuilder::new(values.len(), precision, scale);
+    // 10^scale does not fit in i128 for scale > 38; report it instead of overflowing.
+    let scaling = 10_i128.checked_pow(scale as u32).ok_or_else(|| {
+        ArrowError::ComputeError(format!(
+            "Scale {} is too large for a 128-bit decimal",
+            scale
+        ))
+    })?;
+
+    for maybe_value in values.iter() {
+        match maybe_value {
+            Some(v) => {
+                let v_as_int = v.to_i128().ok_or_else(|| {
+                    ArrowError::ComputeError(format!("Expected integer but got {:?}", v))
+                })?;
+                // A plain `*` panics in debug builds and wraps in release builds.
+                let scaled = v_as_int.checked_mul(scaling).ok_or_else(|| {
+                    ArrowError::ComputeError(format!(
+                        "Overflow while scaling {:?} by 10^{}",
+                        v, scale
+                    ))
+                })?;
+                builder.append_value(scaled)?
+            }
+            None => builder.append_null()?,
+        };
+    }
+    Ok(Arc::new(builder.finish()))
+}
+
+/// Casts a primitive floating point array to a `Decimal(precision, scale)` array.
+///
+/// Each non-null value is multiplied by `10^scale` and rounded to the nearest
+/// integer to get the unscaled decimal representation; nulls are kept. Returns
+/// an error for NaN, infinite or out-of-range values, or if the builder rejects
+/// the value for the given precision.
+fn cast_float_to_decimal<TO>(
+    array: &ArrayRef,
+    precision: usize,
+    scale: usize,
+) -> Result<ArrayRef>
+where
+    TO: ArrowNumericType,
+    TO::Native: num::ToPrimitive,
+{
+    let values = array.as_any().downcast_ref::<PrimitiveArray<TO>>().unwrap();
+
+    let mut builder = DecimalBuilder::new(values.len(), precision, scale);
+    let scaling = 10.0_f64.powi(scale as i32);
+
+    for maybe_value in values.iter() {
+        match maybe_value {
+            Some(v) => {
+                let v_as_float = v.to_f64().ok_or_else(|| {
+                    ArrowError::ComputeError(format!("Expected float but got {:?}", v))
+                })?;
+                // A bare `as i128` cast turns NaN into 0 and clamps out-of-range values,
+                // silently storing wrong decimals; `to_i128` yields `None` for those.
+                let scaled = (v_as_float * scaling).round().to_i128().ok_or_else(|| {
+                    ArrowError::ComputeError(format!(
+                        "Cannot represent {:?} as a decimal with scale {}",
+                        v, scale
+                    ))
+                })?;
+                builder.append_value(scaled)?
+            }
+            None => builder.append_null()?,
+        };
+    }
+    Ok(Arc::new(builder.finish()))
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -1571,6 +1685,77 @@ mod tests {
         assert!(c.is_null(2));
     }
 
+    #[test]
+    fn test_cast_int32_to_decimal() {
+        let a = Int32Array::from(vec![10000, 17890]);
+        let array = Arc::new(a) as ArrayRef;
+        let b = cast(&array, &DataType::Decimal(10, 2)).unwrap();
+        let c = b.as_any().downcast_ref::<DecimalArray>().unwrap();
+        assert_eq!(1000000, c.value(0));
+        assert_eq!(1789000, c.value(1));
+    }
+
+    #[test]
+    fn test_cast_int64_to_decimal() {
+        let a = Int64Array::from(vec![10000, 17890]);
+        let array = Arc::new(a) as ArrayRef;
+        let b = cast(&array, &DataType::Decimal(10, 2)).unwrap();
+        let c = b.as_any().downcast_ref::<DecimalArray>().unwrap();
+        assert_eq!(1000000, c.value(0));
+        assert_eq!(1789000, c.value(1));
+    }
+
+    #[test]
+    fn test_cast_uint64_to_decimal() {
+        let a = UInt64Array::from(vec![10000, 17890]);
+        let array = Arc::new(a) as ArrayRef;
+        let b = cast(&array, &DataType::Decimal(10, 2)).unwrap();
+        let c = b.as_any().downcast_ref::<DecimalArray>().unwrap();
+        assert_eq!(1000000, c.value(0));
+        assert_eq!(1789000, c.value(1));
+    }
+
+    #[test]
+    fn test_cast_int64_to_decimal_exceeds_precision() {
+        let a = Int64Array::from(vec![10000, 17890]);
+        let array = Arc::new(a) as ArrayRef;
+        assert!(cast(&array, &DataType::Decimal(5, 2)).is_err());
+    }
+
+    #[test]
+    fn test_cast_f32_to_decimal() {
+        let a = Float32Array::from(vec![10_000.52, 17_890.499]);
+        let array = Arc::new(a) as ArrayRef;
+        let b = cast(&array, &DataType::Decimal(10, 2)).unwrap();
+        let c = b.as_any().downcast_ref::<DecimalArray>().unwrap();
+        assert_eq!(1000052, c.value(0));
+        assert_eq!(1789050, c.value(1));
+    }
+
+    #[test]
+    fn test_cast_f64_to_decimal() {
+        let a = Float64Array::from(vec![10_000.52, 17_890.499]);
+        let array = Arc::new(a) as ArrayRef;
+        let b = cast(&array, &DataType::Decimal(10, 2)).unwrap();
+        let c = b.as_any().downcast_ref::<DecimalArray>().unwrap();
+        assert_eq!(1000052, c.value(0));
+        assert_eq!(1789050, c.value(1));
+    }
+
+    #[test]
+    fn test_cast_int64_to_decimal_scaling_overflow() {
+        let a = Int64Array::from(vec![std::i64::MAX]);
+        let array = Arc::new(a) as ArrayRef;
+        assert!(cast(&array, &DataType::Decimal(38, 38)).is_err());
+    }
+
+    #[test]
+    fn test_cast_f64_nan_to_decimal() {
+        let a = Float64Array::from(vec![std::f64::NAN]);
+        let array = Arc::new(a) as ArrayRef;
+        assert!(cast(&array, &DataType::Decimal(10, 2)).is_err());
+    }
+
     #[test]
     fn test_cast_from_f64() {
         let f64_values: Vec<f64> = vec![