Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/src/arrow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ pub fn bfloat16_array<'py>(
values: Vec<Option<f32>>,
py: Python<'py>,
) -> PyResult<Bound<'py, PyAny>> {
let array = BFloat16Array::from_iter(values.into_iter().map(|v| v.map(bf16::from_f32)));
let array =
BFloat16Array::from_iter(values.into_iter().map(|v| v.map(bf16::from_f32))).into_inner();

// Create a record batch with a single column and an annotated schema
let field = Field::new("bfloat16", DataType::FixedSizeBinary(2), true).with_metadata(
Expand Down
129 changes: 58 additions & 71 deletions rust/lance-arrow/src/bfloat16.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@
use std::fmt::Formatter;
use std::slice;

use arrow_array::{
builder::BooleanBufferBuilder, iterator::ArrayIter, Array, ArrayAccessor, ArrayRef,
FixedSizeBinaryArray,
};
use arrow_array::{builder::BooleanBufferBuilder, Array, FixedSizeBinaryArray};
use arrow_buffer::MutableBuffer;
use arrow_data::ArrayData;
use arrow_schema::{ArrowError, DataType, Field as ArrowField};
Expand Down Expand Up @@ -41,9 +38,7 @@ pub struct BFloat16Type {}

/// An array of bfloat16 values
///
/// This implements the [`Array`] trait for bfloat16 values. Note that
/// bfloat16 is not the same thing as fp16 which is supported natively
/// by arrow-rs.
/// Note that bfloat16 is not the same thing as fp16 which is supported natively by arrow-rs.
#[derive(Clone)]
pub struct BFloat16Array {
inner: FixedSizeBinaryArray,
Expand Down Expand Up @@ -72,8 +67,27 @@ impl BFloat16Array {
values.into()
}

pub fn len(&self) -> usize {
self.inner.len()
}

pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}

pub fn is_null(&self, i: usize) -> bool {
self.inner.is_null(i)
}

pub fn null_count(&self) -> usize {
self.inner.null_count()
}

pub fn iter(&self) -> BFloat16Iter<'_> {
BFloat16Iter::new(self)
BFloat16Iter {
array: self,
index: 0,
}
}

pub fn value(&self, i: usize) -> bf16 {
Expand All @@ -100,65 +114,6 @@ impl BFloat16Array {
}
}

impl ArrayAccessor for &BFloat16Array {
type Item = bf16;

fn value(&self, index: usize) -> Self::Item {
BFloat16Array::value(self, index)
}

unsafe fn value_unchecked(&self, index: usize) -> Self::Item {
BFloat16Array::value_unchecked(self, index)
}
}

impl Array for BFloat16Array {
fn as_any(&self) -> &dyn std::any::Any {
self.inner.as_any()
}

fn to_data(&self) -> arrow_data::ArrayData {
self.inner.to_data()
}

fn into_data(self) -> arrow_data::ArrayData {
self.inner.into_data()
}

fn slice(&self, offset: usize, length: usize) -> ArrayRef {
let inner_array: &dyn Array = &self.inner;
inner_array.slice(offset, length)
}

fn nulls(&self) -> Option<&arrow_buffer::NullBuffer> {
self.inner.nulls()
}

fn data_type(&self) -> &DataType {
self.inner.data_type()
}

fn len(&self) -> usize {
self.inner.len()
}

fn is_empty(&self) -> bool {
self.inner.is_empty()
}

fn offset(&self) -> usize {
self.inner.offset()
}

fn get_array_memory_size(&self) -> usize {
self.inner.get_array_memory_size()
}

fn get_buffer_memory_size(&self) -> usize {
self.inner.get_buffer_memory_size()
}
}

impl FromIterator<Option<bf16>> for BFloat16Array {
fn from_iter<I: IntoIterator<Item = Option<bf16>>>(iter: I) -> Self {
let mut buffer = MutableBuffer::new(10);
Expand Down Expand Up @@ -242,7 +197,27 @@ impl PartialEq<Self> for BFloat16Array {
}
}

type BFloat16Iter<'a> = ArrayIter<&'a BFloat16Array>;
pub struct BFloat16Iter<'a> {
array: &'a BFloat16Array,
index: usize,
}

impl<'a> Iterator for BFloat16Iter<'a> {
type Item = Option<bf16>;

fn next(&mut self) -> Option<Self::Item> {
if self.index >= self.array.len() {
return None;
}
let i = self.index;
self.index += 1;
if self.array.is_null(i) {
Some(None)
} else {
Some(Some(self.array.value(i)))
}
}
}

/// Methods that are lifted from arrow-rs temporarily until they are made public.
mod from_arrow {
Expand Down Expand Up @@ -290,17 +265,26 @@ mod from_arrow {
}
}

impl FloatArray<BFloat16Type> for BFloat16Array {
impl FloatArray<BFloat16Type> for FixedSizeBinaryArray {
type FloatType = BFloat16Type;

fn as_slice(&self) -> &[bf16] {
assert_eq!(
self.value_length(),
2,
"BFloat16 arrays must use FixedSizeBinary(2) storage"
);
unsafe {
slice::from_raw_parts(
self.inner.value_data().as_ptr() as *const bf16,
self.inner.value_data().len() / 2,
self.value_data().as_ptr() as *const bf16,
self.value_data().len() / 2,
)
}
}

fn from_values(values: Vec<bf16>) -> Self {
BFloat16Array::from(values).into_inner()
}
}

#[cfg(test)]
Expand All @@ -327,6 +311,9 @@ mod tests {
for (expected, value) in values.as_slice().iter().zip(array2.iter()) {
assert_eq!(Some(*expected), value);
}

let arrow_array = array.into_inner();
assert_eq!(arrow_array.as_slice(), values.as_slice());
}

#[test]
Expand Down
60 changes: 51 additions & 9 deletions rust/lance-arrow/src/floats.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use std::{

use arrow_array::{
types::{Float16Type, Float32Type, Float64Type},
Array, Float16Array, Float32Array, Float64Array,
Array, FixedSizeBinaryArray, Float16Array, Float32Array, Float64Array,
};
use arrow_schema::{DataType, Field};
use half::{bf16, f16};
Expand Down Expand Up @@ -95,7 +95,7 @@ pub trait ArrowFloatType: Debug {

/// Returns empty array of this type.
fn empty_array() -> Self::ArrayType {
Vec::<Self::Native>::new().into()
<Self::ArrayType as FloatArray<Self>>::from_values(Vec::new())
}
}

Expand Down Expand Up @@ -143,7 +143,7 @@ impl ArrowFloatType for BFloat16Type {
const MIN: Self::Native = bf16::MIN;
const MAX: Self::Native = bf16::MAX;

type ArrayType = BFloat16Array;
type ArrayType = FixedSizeBinaryArray;
}

impl ArrowFloatType for Float16Type {
Expand Down Expand Up @@ -180,13 +180,22 @@ impl ArrowFloatType for Float64Type {
///
/// This is similar to [`arrow_array::PrimitiveArray`] but applies to all float types (including bfloat16)
/// and is implemented as a trait and not a struct
pub trait FloatArray<T: ArrowFloatType + ?Sized>:
Array + Clone + From<Vec<T::Native>> + 'static
{
pub trait FloatArray<T: ArrowFloatType + ?Sized>: Array + Clone + 'static {
type FloatType: ArrowFloatType;

/// Returns a reference to the underlying data as a slice.
fn as_slice(&self) -> &[T::Native];

/// Construct an array from a vector of values.
fn from_values(values: Vec<T::Native>) -> Self;

/// Construct an array from an iterator of values.
fn from_iter_values(values: impl IntoIterator<Item = T::Native>) -> Self
where
Self: Sized,
{
Self::from_values(values.into_iter().collect())
}
}

impl FloatArray<Float16Type> for Float16Array {
Expand All @@ -195,6 +204,10 @@ impl FloatArray<Float16Type> for Float16Array {
fn as_slice(&self) -> &[<Float16Type as ArrowFloatType>::Native] {
self.values()
}

fn from_values(values: Vec<<Float16Type as ArrowFloatType>::Native>) -> Self {
Self::from(values)
}
}

impl FloatArray<Float32Type> for Float32Array {
Expand All @@ -203,6 +216,10 @@ impl FloatArray<Float32Type> for Float32Array {
fn as_slice(&self) -> &[<Float32Type as ArrowFloatType>::Native] {
self.values()
}

fn from_values(values: Vec<<Float32Type as ArrowFloatType>::Native>) -> Self {
Self::from(values)
}
}

impl FloatArray<Float64Type> for Float64Array {
Expand All @@ -211,6 +228,10 @@ impl FloatArray<Float64Type> for Float64Array {
fn as_slice(&self) -> &[<Float64Type as ArrowFloatType>::Native] {
self.values()
}

fn from_values(values: Vec<<Float64Type as ArrowFloatType>::Native>) -> Self {
Self::from(values)
}
}

/// Convert a float32 array to another float array
Expand All @@ -219,9 +240,10 @@ impl FloatArray<Float64Type> for Float64Array {
/// and need to be converted to the appropriate float type for the index.
pub fn coerce_float_vector(input: &Float32Array, float_type: FloatType) -> Result<Arc<dyn Array>> {
match float_type {
FloatType::BFloat16 => Ok(Arc::new(BFloat16Array::from_iter_values(
input.values().iter().map(|v| bf16::from_f32(*v)),
))),
FloatType::BFloat16 => Ok(Arc::new(
BFloat16Array::from_iter_values(input.values().iter().map(|v| bf16::from_f32(*v)))
.into_inner(),
)),
FloatType::Float16 => Ok(Arc::new(Float16Array::from_iter_values(
input.values().iter().map(|v| f16::from_f32(*v)),
))),
Expand All @@ -231,3 +253,23 @@ pub fn coerce_float_vector(input: &Float32Array, float_type: FloatType) -> Resul
))),
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_coerce_float_vector_bfloat16() {
let input = Float32Array::from(vec![1.0f32, 2.0, 3.0]);
let array = coerce_float_vector(&input, FloatType::BFloat16).unwrap();

assert_eq!(array.data_type(), &DataType::FixedSizeBinary(2));

let fixed = array
.as_any()
.downcast_ref::<FixedSizeBinaryArray>()
.unwrap();
let expected: Vec<bf16> = input.values().iter().map(|v| bf16::from_f32(*v)).collect();
assert_eq!(fixed.as_slice(), expected.as_slice());
}
}
Loading
Loading