diff --git a/arrow-json/src/lib.rs b/arrow-json/src/lib.rs index ea0446c3d6b3..6d7ab4400b6e 100644 --- a/arrow-json/src/lib.rs +++ b/arrow-json/src/lib.rs @@ -75,7 +75,10 @@ pub mod reader; pub mod writer; pub use self::reader::{Reader, ReaderBuilder}; -pub use self::writer::{ArrayWriter, LineDelimitedWriter, Writer, WriterBuilder}; +pub use self::writer::{ + ArrayWriter, Encoder, EncoderFactory, EncoderOptions, LineDelimitedWriter, Writer, + WriterBuilder, +}; use half::f16; use serde_json::{Number, Value}; diff --git a/arrow-json/src/writer/encoder.rs b/arrow-json/src/writer/encoder.rs index 0b3c788d5519..ee6af03101f8 100644 --- a/arrow-json/src/writer/encoder.rs +++ b/arrow-json/src/writer/encoder.rs @@ -14,6 +14,8 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +use std::io::Write; +use std::sync::Arc; use crate::StructMode; use arrow_array::cast::AsArray; @@ -25,126 +27,322 @@ use arrow_schema::{ArrowError, DataType, FieldRef}; use half::f16; use lexical_core::FormattedSize; use serde::Serializer; -use std::io::Write; +/// Configuration options for the JSON encoder. #[derive(Debug, Clone, Default)] pub struct EncoderOptions { - pub explicit_nulls: bool, - pub struct_mode: StructMode, + /// Whether to include nulls in the output or elide them. + explicit_nulls: bool, + /// Whether to encode structs as JSON objects or JSON arrays of their values. + struct_mode: StructMode, + /// An optional hook for customizing encoding behavior. + encoder_factory: Option>, +} + +impl EncoderOptions { + /// Set whether to include nulls in the output or elide them. + pub fn with_explicit_nulls(mut self, explicit_nulls: bool) -> Self { + self.explicit_nulls = explicit_nulls; + self + } + + /// Set whether to encode structs as JSON objects or JSON arrays of their values. + pub fn with_struct_mode(mut self, struct_mode: StructMode) -> Self { + self.struct_mode = struct_mode; + self + } + + /// Set an optional hook for customizing encoding behavior. + pub fn with_encoder_factory(mut self, encoder_factory: Arc) -> Self { + self.encoder_factory = Some(encoder_factory); + self + } + + /// Get whether to include nulls in the output or elide them. + pub fn explicit_nulls(&self) -> bool { + self.explicit_nulls + } + + /// Get whether to encode structs as JSON objects or JSON arrays of their values. + pub fn struct_mode(&self) -> StructMode { + self.struct_mode + } + + /// Get the optional hook for customizing encoding behavior. + pub fn encoder_factory(&self) -> Option<&Arc> { + self.encoder_factory.as_ref() + } +} + +/// A trait to create custom encoders for specific data types. +/// +/// This allows overriding the default encoders for specific data types, +/// or adding new encoders for custom data types. +/// +/// # Examples +/// +/// ``` +/// use std::io::Write; +/// use arrow_array::{ArrayAccessor, Array, BinaryArray, Float64Array, RecordBatch}; +/// use arrow_array::cast::AsArray; +/// use arrow_schema::{DataType, Field, Schema, FieldRef}; +/// use arrow_json::{writer::{WriterBuilder, JsonArray, NullableEncoder}, StructMode}; +/// use arrow_json::{Encoder, EncoderFactory, EncoderOptions}; +/// use arrow_schema::ArrowError; +/// use std::sync::Arc; +/// use serde_json::json; +/// use serde_json::Value; +/// +/// struct IntArrayBinaryEncoder { +/// array: B, +/// } +/// +/// impl<'a, B> Encoder for IntArrayBinaryEncoder +/// where +/// B: ArrayAccessor, +/// { +/// fn encode(&mut self, idx: usize, out: &mut Vec) { +/// out.push(b'['); +/// let child = self.array.value(idx); +/// for (idx, byte) in child.iter().enumerate() { +/// write!(out, "{byte}").unwrap(); +/// if idx < child.len() - 1 { +/// out.push(b','); +/// } +/// } +/// out.push(b']'); +/// } +/// } +/// +/// #[derive(Debug)] +/// struct IntArayBinaryEncoderFactory; +/// +/// impl EncoderFactory for IntArayBinaryEncoderFactory { +/// fn make_default_encoder<'a>( +/// &self, +/// _field: &'a FieldRef, +/// array: &'a dyn Array, +/// _options: &'a EncoderOptions, +/// ) -> Result>, ArrowError> { +/// match array.data_type() { +/// DataType::Binary => { +/// let array = array.as_binary::(); +/// let encoder = IntArrayBinaryEncoder { array }; +/// let array_encoder = Box::new(encoder) as Box; +/// let nulls = array.nulls().cloned(); +/// Ok(Some(NullableEncoder::new(array_encoder, nulls))) +/// } +/// _ => Ok(None), +/// } +/// } +/// } +/// +/// let binary_array = BinaryArray::from_iter([Some(b"a".as_slice()), None, Some(b"b".as_slice())]); +/// let float_array = Float64Array::from(vec![Some(1.0), Some(2.3), None]); +/// let fields = vec![ +/// Field::new("bytes", DataType::Binary, true), +/// Field::new("float", DataType::Float64, true), +/// ]; +/// let batch = RecordBatch::try_new( +/// Arc::new(Schema::new(fields)), +/// vec![ +/// Arc::new(binary_array) as Arc, +/// Arc::new(float_array) as Arc, +/// ], +/// ) +/// .unwrap(); +/// +/// let json_value: Value = { +/// let mut buf = Vec::new(); +/// let mut writer = WriterBuilder::new() +/// .with_encoder_factory(Arc::new(IntArayBinaryEncoderFactory)) +/// .build::<_, JsonArray>(&mut buf); +/// writer.write_batches(&[&batch]).unwrap(); +/// writer.finish().unwrap(); +/// serde_json::from_slice(&buf).unwrap() +/// }; +/// +/// let expected = json!([ +/// {"bytes": [97], "float": 1.0}, +/// {"float": 2.3}, +/// {"bytes": [98]}, +/// ]); +/// +/// assert_eq!(json_value, expected); +/// ``` +pub trait EncoderFactory: std::fmt::Debug + Send + Sync { + /// Make an encoder that overrides the default encoder for a specific field and array or provides an encoder for a custom data type. + /// This can be used to override how e.g. binary data is encoded so that it is an encoded string or an array of integers. + /// + /// Note that the type of the field may not match the type of the array: for dictionary arrays unless the top-level dictionary is handled this + /// will be called again for the keys and values of the dictionary, at which point the field type will still be the outer dictionary type but the + /// array will have a different type. + /// For example, `field`` might have the type `Dictionary(i32, Utf8)` but `array` will be `Utf8`. + fn make_default_encoder<'a>( + &self, + _field: &'a FieldRef, + _array: &'a dyn Array, + _options: &'a EncoderOptions, + ) -> Result>, ArrowError> { + Ok(None) + } +} + +/// An encoder + a null buffer. +/// This is packaged together into a wrapper struct to minimize dynamic dispatch for null checks. +pub struct NullableEncoder<'a> { + encoder: Box, + nulls: Option, +} + +impl<'a> NullableEncoder<'a> { + /// Create a new encoder with a null buffer. + pub fn new(encoder: Box, nulls: Option) -> Self { + Self { encoder, nulls } + } + + /// Encode the value at index `idx` to `out`. + pub fn encode(&mut self, idx: usize, out: &mut Vec) { + self.encoder.encode(idx, out) + } + + /// Returns whether the value at index `idx` is null. + pub fn is_null(&self, idx: usize) -> bool { + self.nulls.as_ref().is_some_and(|nulls| nulls.is_null(idx)) + } + + /// Returns whether the encoder has any nulls. + pub fn has_nulls(&self) -> bool { + match self.nulls { + Some(ref nulls) => nulls.null_count() > 0, + None => false, + } + } +} + +impl Encoder for NullableEncoder<'_> { + fn encode(&mut self, idx: usize, out: &mut Vec) { + self.encoder.encode(idx, out) + } } /// A trait to format array values as JSON values /// /// Nullability is handled by the caller to allow encoding nulls implicitly, i.e. `{}` instead of `{"a": null}` pub trait Encoder { - /// Encode the non-null value at index `idx` to `out` + /// Encode the non-null value at index `idx` to `out`. /// - /// The behaviour is unspecified if `idx` corresponds to a null index + /// The behaviour is unspecified if `idx` corresponds to a null index. fn encode(&mut self, idx: usize, out: &mut Vec); } +/// Creates an encoder for the given array and field. +/// +/// This first calls the EncoderFactory if one is provided, and then falls back to the default encoders. pub fn make_encoder<'a>( + field: &'a FieldRef, array: &'a dyn Array, - options: &EncoderOptions, -) -> Result, ArrowError> { - let (encoder, nulls) = make_encoder_impl(array, options)?; - assert!(nulls.is_none(), "root cannot be nullable"); - Ok(encoder) -} - -fn make_encoder_impl<'a>( - array: &'a dyn Array, - options: &EncoderOptions, -) -> Result<(Box, Option), ArrowError> { + options: &'a EncoderOptions, +) -> Result, ArrowError> { macro_rules! primitive_helper { ($t:ty) => {{ let array = array.as_primitive::<$t>(); let nulls = array.nulls().cloned(); - (Box::new(PrimitiveEncoder::new(array)) as _, nulls) + NullableEncoder::new(Box::new(PrimitiveEncoder::new(array)), nulls) }}; } - Ok(downcast_integer! { + if let Some(factory) = options.encoder_factory() { + if let Some(encoder) = factory.make_default_encoder(field, array, options)? { + return Ok(encoder); + } + } + + let nulls = array.nulls().cloned(); + let encoder = downcast_integer! { array.data_type() => (primitive_helper), DataType::Float16 => primitive_helper!(Float16Type), DataType::Float32 => primitive_helper!(Float32Type), DataType::Float64 => primitive_helper!(Float64Type), DataType::Boolean => { let array = array.as_boolean(); - (Box::new(BooleanEncoder(array)), array.nulls().cloned()) + NullableEncoder::new(Box::new(BooleanEncoder(array)), array.nulls().cloned()) } - DataType::Null => (Box::new(NullEncoder), array.logical_nulls()), + DataType::Null => NullableEncoder::new(Box::new(NullEncoder), array.logical_nulls()), DataType::Utf8 => { let array = array.as_string::(); - (Box::new(StringEncoder(array)) as _, array.nulls().cloned()) + NullableEncoder::new(Box::new(StringEncoder(array)), array.nulls().cloned()) } DataType::LargeUtf8 => { let array = array.as_string::(); - (Box::new(StringEncoder(array)) as _, array.nulls().cloned()) + NullableEncoder::new(Box::new(StringEncoder(array)), array.nulls().cloned()) } DataType::Utf8View => { let array = array.as_string_view(); - (Box::new(StringViewEncoder(array)) as _, array.nulls().cloned()) + NullableEncoder::new(Box::new(StringViewEncoder(array)), array.nulls().cloned()) } DataType::List(_) => { let array = array.as_list::(); - (Box::new(ListEncoder::try_new(array, options)?) as _, array.nulls().cloned()) + NullableEncoder::new(Box::new(ListEncoder::try_new(field, array, options)?), array.nulls().cloned()) } DataType::LargeList(_) => { let array = array.as_list::(); - (Box::new(ListEncoder::try_new(array, options)?) as _, array.nulls().cloned()) + NullableEncoder::new(Box::new(ListEncoder::try_new(field, array, options)?), array.nulls().cloned()) } DataType::FixedSizeList(_, _) => { let array = array.as_fixed_size_list(); - (Box::new(FixedSizeListEncoder::try_new(array, options)?) as _, array.nulls().cloned()) + NullableEncoder::new(Box::new(FixedSizeListEncoder::try_new(field, array, options)?), array.nulls().cloned()) } DataType::Dictionary(_, _) => downcast_dictionary_array! { - array => (Box::new(DictionaryEncoder::try_new(array, options)?) as _, array.logical_nulls()), + array => { + NullableEncoder::new(Box::new(DictionaryEncoder::try_new(field, array, options)?), array.nulls().cloned()) + }, _ => unreachable!() } DataType::Map(_, _) => { let array = array.as_map(); - (Box::new(MapEncoder::try_new(array, options)?) as _, array.nulls().cloned()) + NullableEncoder::new(Box::new(MapEncoder::try_new(field, array, options)?), array.nulls().cloned()) } DataType::FixedSizeBinary(_) => { let array = array.as_fixed_size_binary(); - (Box::new(BinaryEncoder::new(array)) as _, array.nulls().cloned()) + NullableEncoder::new(Box::new(BinaryEncoder::new(array)) as _, array.nulls().cloned()) } DataType::Binary => { let array: &BinaryArray = array.as_binary(); - (Box::new(BinaryEncoder::new(array)) as _, array.nulls().cloned()) + NullableEncoder::new(Box::new(BinaryEncoder::new(array)), array.nulls().cloned()) } DataType::LargeBinary => { let array: &LargeBinaryArray = array.as_binary(); - (Box::new(BinaryEncoder::new(array)) as _, array.nulls().cloned()) + NullableEncoder::new(Box::new(BinaryEncoder::new(array)), array.nulls().cloned()) } DataType::Struct(fields) => { let array = array.as_struct(); let encoders = fields.iter().zip(array.columns()).map(|(field, array)| { - let (encoder, nulls) = make_encoder_impl(array, options)?; + let encoder = make_encoder(field, array, options)?; Ok(FieldEncoder{ field: field.clone(), - encoder, nulls + encoder, }) }).collect::, ArrowError>>()?; let encoder = StructArrayEncoder{ encoders, - explicit_nulls: options.explicit_nulls, - struct_mode: options.struct_mode, + explicit_nulls: options.explicit_nulls(), + struct_mode: options.struct_mode(), }; - (Box::new(encoder) as _, array.nulls().cloned()) + let nulls = array.nulls().cloned(); + NullableEncoder::new(Box::new(encoder) as Box, nulls) } DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => { let options = FormatOptions::new().with_display_error(true); - let formatter = ArrayFormatter::try_new(array, &options)?; - (Box::new(RawArrayFormatter(formatter)) as _, array.nulls().cloned()) + let formatter = JsonArrayFormatter::new(ArrayFormatter::try_new(array, &options)?); + NullableEncoder::new(Box::new(RawArrayFormatter(formatter)) as Box, nulls) } d => match d.is_temporal() { true => { @@ -154,11 +352,17 @@ fn make_encoder_impl<'a>( // may need to be revisited let options = FormatOptions::new().with_display_error(true); let formatter = ArrayFormatter::try_new(array, &options)?; - (Box::new(formatter) as _, array.nulls().cloned()) + let formatter = JsonArrayFormatter::new(formatter); + NullableEncoder::new(Box::new(formatter) as Box, nulls) } - false => return Err(ArrowError::InvalidArgumentError(format!("JSON Writer does not support data type: {d}"))), + false => return Err(ArrowError::JsonError(format!( + "Unsupported data type for JSON encoding: {:?}", + d + ))) } - }) + }; + + Ok(encoder) } fn encode_string(s: &str, out: &mut Vec) { @@ -168,8 +372,13 @@ fn encode_string(s: &str, out: &mut Vec) { struct FieldEncoder<'a> { field: FieldRef, - encoder: Box, - nulls: Option, + encoder: NullableEncoder<'a>, +} + +impl FieldEncoder<'_> { + fn is_null(&self, idx: usize) -> bool { + self.encoder.is_null(idx) + } } struct StructArrayEncoder<'a> { @@ -196,9 +405,10 @@ impl Encoder for StructArrayEncoder<'_> { let mut is_first = true; // Nulls can only be dropped in explicit mode let drop_nulls = (self.struct_mode == StructMode::ObjectOnly) && !self.explicit_nulls; - for field_encoder in &mut self.encoders { - let is_null = is_some_and(field_encoder.nulls.as_ref(), |n| n.is_null(idx)); - if drop_nulls && is_null { + + for field_encoder in self.encoders.iter_mut() { + let is_null = field_encoder.is_null(idx); + if is_null && drop_nulls { continue; } @@ -212,9 +422,10 @@ impl Encoder for StructArrayEncoder<'_> { out.push(b':'); } - match is_null { - true => out.extend_from_slice(b"null"), - false => field_encoder.encoder.encode(idx, out), + if is_null { + out.extend_from_slice(b"null"); + } else { + field_encoder.encoder.encode(idx, out); } } match self.struct_mode { @@ -339,20 +550,19 @@ impl Encoder for StringViewEncoder<'_> { struct ListEncoder<'a, O: OffsetSizeTrait> { offsets: OffsetBuffer, - nulls: Option, - encoder: Box, + encoder: NullableEncoder<'a>, } impl<'a, O: OffsetSizeTrait> ListEncoder<'a, O> { fn try_new( + field: &'a FieldRef, array: &'a GenericListArray, - options: &EncoderOptions, + options: &'a EncoderOptions, ) -> Result { - let (encoder, nulls) = make_encoder_impl(array.values().as_ref(), options)?; + let encoder = make_encoder(field, array.values().as_ref(), options)?; Ok(Self { offsets: array.offsets().clone(), encoder, - nulls, }) } } @@ -362,22 +572,25 @@ impl Encoder for ListEncoder<'_, O> { let end = self.offsets[idx + 1].as_usize(); let start = self.offsets[idx].as_usize(); out.push(b'['); - match self.nulls.as_ref() { - Some(n) => (start..end).for_each(|idx| { + + if self.encoder.has_nulls() { + for idx in start..end { if idx != start { out.push(b',') } - match n.is_null(idx) { - true => out.extend_from_slice(b"null"), - false => self.encoder.encode(idx, out), + if self.encoder.is_null(idx) { + out.extend_from_slice(b"null"); + } else { + self.encoder.encode(idx, out); } - }), - None => (start..end).for_each(|idx| { + } + } else { + for idx in start..end { if idx != start { out.push(b',') } self.encoder.encode(idx, out); - }), + } } out.push(b']'); } @@ -385,19 +598,18 @@ impl Encoder for ListEncoder<'_, O> { struct FixedSizeListEncoder<'a> { value_length: usize, - nulls: Option, - encoder: Box, + encoder: NullableEncoder<'a>, } impl<'a> FixedSizeListEncoder<'a> { fn try_new( + field: &'a FieldRef, array: &'a FixedSizeListArray, - options: &EncoderOptions, + options: &'a EncoderOptions, ) -> Result { - let (encoder, nulls) = make_encoder_impl(array.values().as_ref(), options)?; + let encoder = make_encoder(field, array.values().as_ref(), options)?; Ok(Self { encoder, - nulls, value_length: array.value_length().as_usize(), }) } @@ -408,23 +620,24 @@ impl Encoder for FixedSizeListEncoder<'_> { let start = idx * self.value_length; let end = start + self.value_length; out.push(b'['); - match self.nulls.as_ref() { - Some(n) => (start..end).for_each(|idx| { + if self.encoder.has_nulls() { + for idx in start..end { if idx != start { - out.push(b','); + out.push(b',') } - if n.is_null(idx) { + if self.encoder.is_null(idx) { out.extend_from_slice(b"null"); } else { self.encoder.encode(idx, out); } - }), - None => (start..end).for_each(|idx| { + } + } else { + for idx in start..end { if idx != start { - out.push(b','); + out.push(b',') } self.encoder.encode(idx, out); - }), + } } out.push(b']'); } @@ -432,15 +645,16 @@ impl Encoder for FixedSizeListEncoder<'_> { struct DictionaryEncoder<'a, K: ArrowDictionaryKeyType> { keys: ScalarBuffer, - encoder: Box, + encoder: NullableEncoder<'a>, } impl<'a, K: ArrowDictionaryKeyType> DictionaryEncoder<'a, K> { fn try_new( + field: &'a FieldRef, array: &'a DictionaryArray, - options: &EncoderOptions, + options: &'a EncoderOptions, ) -> Result { - let (encoder, _) = make_encoder_impl(array.values().as_ref(), options)?; + let encoder = make_encoder(field, array.values().as_ref(), options)?; Ok(Self { keys: array.keys().values().clone(), @@ -455,22 +669,33 @@ impl Encoder for DictionaryEncoder<'_, K> { } } -impl Encoder for ArrayFormatter<'_> { +/// A newtype wrapper around [`ArrayFormatter`] to keep our usage of it private and not implement `Encoder` for the public type +struct JsonArrayFormatter<'a> { + formatter: ArrayFormatter<'a>, +} + +impl<'a> JsonArrayFormatter<'a> { + fn new(formatter: ArrayFormatter<'a>) -> Self { + Self { formatter } + } +} + +impl Encoder for JsonArrayFormatter<'_> { fn encode(&mut self, idx: usize, out: &mut Vec) { out.push(b'"'); // Should be infallible // Note: We are making an assumption that the formatter does not produce characters that require escaping - let _ = write!(out, "{}", self.value(idx)); + let _ = write!(out, "{}", self.formatter.value(idx)); out.push(b'"') } } -/// A newtype wrapper around [`ArrayFormatter`] that skips surrounding the value with `"` -struct RawArrayFormatter<'a>(ArrayFormatter<'a>); +/// A newtype wrapper around [`JsonArrayFormatter`] that skips surrounding the value with `"` +struct RawArrayFormatter<'a>(JsonArrayFormatter<'a>); impl Encoder for RawArrayFormatter<'_> { fn encode(&mut self, idx: usize, out: &mut Vec) { - let _ = write!(out, "{}", self.0.value(idx)); + let _ = write!(out, "{}", self.0.formatter.value(idx)); } } @@ -484,14 +709,17 @@ impl Encoder for NullEncoder { struct MapEncoder<'a> { offsets: OffsetBuffer, - keys: Box, - values: Box, - value_nulls: Option, + keys: NullableEncoder<'a>, + values: NullableEncoder<'a>, explicit_nulls: bool, } impl<'a> MapEncoder<'a> { - fn try_new(array: &'a MapArray, options: &EncoderOptions) -> Result { + fn try_new( + field: &'a FieldRef, + array: &'a MapArray, + options: &'a EncoderOptions, + ) -> Result { let values = array.values(); let keys = array.keys(); @@ -502,11 +730,11 @@ impl<'a> MapEncoder<'a> { ))); } - let (keys, key_nulls) = make_encoder_impl(keys, options)?; - let (values, value_nulls) = make_encoder_impl(values, options)?; + let keys = make_encoder(field, keys, options)?; + let values = make_encoder(field, values, options)?; // We sanity check nulls as these are currently not enforced by MapArray (#1697) - if is_some_and(key_nulls, |x| x.null_count() != 0) { + if keys.has_nulls() { return Err(ArrowError::InvalidArgumentError( "Encountered nulls in MapArray keys".to_string(), )); @@ -522,8 +750,7 @@ impl<'a> MapEncoder<'a> { offsets: array.offsets().clone(), keys, values, - value_nulls, - explicit_nulls: options.explicit_nulls, + explicit_nulls: options.explicit_nulls(), }) } } @@ -536,8 +763,9 @@ impl Encoder for MapEncoder<'_> { let mut is_first = true; out.push(b'{'); + for idx in start..end { - let is_null = is_some_and(self.value_nulls.as_ref(), |n| n.is_null(idx)); + let is_null = self.values.is_null(idx); if is_null && !self.explicit_nulls { continue; } @@ -550,9 +778,10 @@ impl Encoder for MapEncoder<'_> { self.keys.encode(idx, out); out.push(b':'); - match is_null { - true => out.extend_from_slice(b"null"), - false => self.values.encode(idx, out), + if is_null { + out.extend_from_slice(b"null"); + } else { + self.values.encode(idx, out); } } out.push(b'}'); diff --git a/arrow-json/src/writer/mod.rs b/arrow-json/src/writer/mod.rs index 5d3e558480ca..ee1b5fabe538 100644 --- a/arrow-json/src/writer/mod.rs +++ b/arrow-json/src/writer/mod.rs @@ -106,13 +106,13 @@ //! ``` mod encoder; -use std::{fmt::Debug, io::Write}; +use std::{fmt::Debug, io::Write, sync::Arc}; use crate::StructMode; use arrow_array::*; use arrow_schema::*; -use encoder::{make_encoder, EncoderOptions}; +pub use encoder::{make_encoder, Encoder, EncoderFactory, EncoderOptions, NullableEncoder}; /// This trait defines how to format a sequence of JSON objects to a /// byte stream. @@ -225,7 +225,7 @@ impl WriterBuilder { /// Returns `true` if this writer is configured to keep keys with null values. pub fn explicit_nulls(&self) -> bool { - self.0.explicit_nulls + self.0.explicit_nulls() } /// Set whether to keep keys with null values, or to omit writing them. @@ -251,13 +251,13 @@ impl WriterBuilder { /// Default is to skip nulls (set to `false`). If `struct_mode == ListOnly`, /// nulls will be written explicitly regardless of this setting. pub fn with_explicit_nulls(mut self, explicit_nulls: bool) -> Self { - self.0.explicit_nulls = explicit_nulls; + self.0 = self.0.with_explicit_nulls(explicit_nulls); self } /// Returns if this writer is configured to write structs as JSON Objects or Arrays. pub fn struct_mode(&self) -> StructMode { - self.0.struct_mode + self.0.struct_mode() } /// Set the [`StructMode`] for the writer, which determines whether structs @@ -266,7 +266,16 @@ impl WriterBuilder { /// `ListOnly`, nulls will be written explicitly regardless of the /// `explicit_nulls` setting. pub fn with_struct_mode(mut self, struct_mode: StructMode) -> Self { - self.0.struct_mode = struct_mode; + self.0 = self.0.with_struct_mode(struct_mode); + self + } + + /// Set an encoder factory to use when creating encoders for writing JSON. + /// + /// This can be used to override how some types are encoded or to provide + /// a fallback for types that are not supported by the default encoder. + pub fn with_encoder_factory(mut self, factory: Arc) -> Self { + self.0 = self.0.with_encoder_factory(factory); self } @@ -351,8 +360,16 @@ where } let array = StructArray::from(batch.clone()); - let mut encoder = make_encoder(&array, &self.options)?; + let field = Arc::new(Field::new_struct( + "", + batch.schema().fields().clone(), + false, + )); + + let mut encoder = make_encoder(&field, &array, &self.options)?; + // Validate that the root is not nullable + assert!(!encoder.has_nulls(), "root cannot be nullable"); for idx in 0..batch.num_rows() { self.format.start_row(&mut buffer, is_first_row)?; is_first_row = false; @@ -419,15 +436,19 @@ where #[cfg(test)] mod tests { use core::str; + use std::collections::HashMap; use std::fs::{read_to_string, File}; use std::io::{BufReader, Seek}; use std::sync::Arc; + use arrow_array::cast::AsArray; use serde_json::{json, Value}; + use super::LineDelimited; + use super::{Encoder, WriterBuilder}; use arrow_array::builder::*; use arrow_array::types::*; - use arrow_buffer::{i256, Buffer, NullBuffer, OffsetBuffer, ToByteSlice}; + use arrow_buffer::{i256, Buffer, NullBuffer, OffsetBuffer, ScalarBuffer, ToByteSlice}; use arrow_data::ArrayData; use crate::reader::*; @@ -446,7 +467,7 @@ mod tests { .map(|s| (!s.is_empty()).then(|| serde_json::from_slice(s).unwrap())) .collect(); - assert_eq!(expected, actual); + assert_eq!(actual, expected); } #[test] @@ -1891,7 +1912,7 @@ mod tests { let json_str = str::from_utf8(&json).unwrap(); assert_eq!( json_str, - r#"[{"my_dict":"a"},{"my_dict":null},{"my_dict":null}]"# + r#"[{"my_dict":"a"},{"my_dict":null},{"my_dict":""}]"# ) } @@ -2036,4 +2057,414 @@ mod tests { } assert_json_eq(&buf, expected); } + + fn make_fallback_encoder_test_data() -> (RecordBatch, Arc) { + // Note: this is not intended to be an efficient implementation. + // Just a simple example to demonstrate how to implement a custom encoder. + #[derive(Debug)] + enum UnionValue { + Int32(i32), + String(String), + } + + #[derive(Debug)] + struct UnionEncoder { + array: Vec>, + } + + impl Encoder for UnionEncoder { + fn encode(&mut self, idx: usize, out: &mut Vec) { + match &self.array[idx] { + None => out.extend_from_slice(b"null"), + Some(UnionValue::Int32(v)) => out.extend_from_slice(v.to_string().as_bytes()), + Some(UnionValue::String(v)) => { + out.extend_from_slice(format!("\"{}\"", v).as_bytes()) + } + } + } + } + + #[derive(Debug)] + struct UnionEncoderFactory; + + impl EncoderFactory for UnionEncoderFactory { + fn make_default_encoder<'a>( + &self, + _field: &'a FieldRef, + array: &'a dyn Array, + _options: &'a EncoderOptions, + ) -> Result>, ArrowError> { + let data_type = array.data_type(); + let fields = match data_type { + DataType::Union(fields, UnionMode::Sparse) => fields, + _ => return Ok(None), + }; + // check that the fields are supported + let fields = fields.iter().map(|(_, f)| f).collect::>(); + for f in fields.iter() { + match f.data_type() { + DataType::Null => {} + DataType::Int32 => {} + DataType::Utf8 => {} + _ => return Ok(None), + } + } + let (_, type_ids, _, buffers) = array.as_union().clone().into_parts(); + let mut values = Vec::with_capacity(type_ids.len()); + for idx in 0..type_ids.len() { + let type_id = type_ids[idx]; + let field = &fields[type_id as usize]; + let value = match field.data_type() { + DataType::Null => None, + DataType::Int32 => Some(UnionValue::Int32( + buffers[type_id as usize] + .as_primitive::() + .value(idx), + )), + DataType::Utf8 => Some(UnionValue::String( + buffers[type_id as usize] + .as_string::() + .value(idx) + .to_string(), + )), + _ => unreachable!(), + }; + values.push(value); + } + let array_encoder = + Box::new(UnionEncoder { array: values }) as Box; + let nulls = array.nulls().cloned(); + Ok(Some(NullableEncoder::new(array_encoder, nulls))) + } + } + + let int_array = Int32Array::from(vec![Some(1), None, None]); + let string_array = StringArray::from(vec![None, Some("a"), None]); + let null_array = NullArray::new(3); + let type_ids = [0_i8, 1, 2].into_iter().collect::>(); + + let union_fields = [ + (0, Arc::new(Field::new("A", DataType::Int32, false))), + (1, Arc::new(Field::new("B", DataType::Utf8, false))), + (2, Arc::new(Field::new("C", DataType::Null, false))), + ] + .into_iter() + .collect::(); + + let children = vec![ + Arc::new(int_array) as Arc, + Arc::new(string_array), + Arc::new(null_array), + ]; + + let array = UnionArray::try_new(union_fields.clone(), type_ids, None, children).unwrap(); + + let float_array = Float64Array::from(vec![Some(1.0), None, Some(3.4)]); + + let fields = vec![ + Field::new( + "union", + DataType::Union(union_fields, UnionMode::Sparse), + true, + ), + Field::new("float", DataType::Float64, true), + ]; + + let batch = RecordBatch::try_new( + Arc::new(Schema::new(fields)), + vec![ + Arc::new(array) as Arc, + Arc::new(float_array) as Arc, + ], + ) + .unwrap(); + + (batch, Arc::new(UnionEncoderFactory)) + } + + #[test] + fn test_fallback_encoder_factory_line_delimited_implicit_nulls() { + let (batch, encoder_factory) = make_fallback_encoder_test_data(); + + let mut buf = Vec::new(); + { + let mut writer = WriterBuilder::new() + .with_encoder_factory(encoder_factory) + .with_explicit_nulls(false) + .build::<_, LineDelimited>(&mut buf); + writer.write_batches(&[&batch]).unwrap(); + writer.finish().unwrap(); + } + + println!("{}", str::from_utf8(&buf).unwrap()); + + assert_json_eq( + &buf, + r#"{"union":1,"float":1.0} +{"union":"a"} +{"union":null,"float":3.4} +"#, + ); + } + + #[test] + fn test_fallback_encoder_factory_line_delimited_explicit_nulls() { + let (batch, encoder_factory) = make_fallback_encoder_test_data(); + + let mut buf = Vec::new(); + { + let mut writer = WriterBuilder::new() + .with_encoder_factory(encoder_factory) + .with_explicit_nulls(true) + .build::<_, LineDelimited>(&mut buf); + writer.write_batches(&[&batch]).unwrap(); + writer.finish().unwrap(); + } + + assert_json_eq( + &buf, + r#"{"union":1,"float":1.0} +{"union":"a","float":null} +{"union":null,"float":3.4} +"#, + ); + } + + #[test] + fn test_fallback_encoder_factory_array_implicit_nulls() { + let (batch, encoder_factory) = make_fallback_encoder_test_data(); + + let json_value: Value = { + let mut buf = Vec::new(); + let mut writer = WriterBuilder::new() + .with_encoder_factory(encoder_factory) + .build::<_, JsonArray>(&mut buf); + writer.write_batches(&[&batch]).unwrap(); + writer.finish().unwrap(); + serde_json::from_slice(&buf).unwrap() + }; + + let expected = json!([ + {"union":1,"float":1.0}, + {"union":"a"}, + {"float":3.4,"union":null}, + ]); + + assert_eq!(json_value, expected); + } + + #[test] + fn test_fallback_encoder_factory_array_explicit_nulls() { + let (batch, encoder_factory) = make_fallback_encoder_test_data(); + + let json_value: Value = { + let mut buf = Vec::new(); + let mut writer = WriterBuilder::new() + .with_encoder_factory(encoder_factory) + .with_explicit_nulls(true) + .build::<_, JsonArray>(&mut buf); + writer.write_batches(&[&batch]).unwrap(); + writer.finish().unwrap(); + serde_json::from_slice(&buf).unwrap() + }; + + let expected = json!([ + {"union":1,"float":1.0}, + {"union":"a", "float": null}, + {"union":null,"float":3.4}, + ]); + + assert_eq!(json_value, expected); + } + + #[test] + fn test_default_encoder_byte_array() { + struct IntArrayBinaryEncoder { + array: B, + } + + impl<'a, B> Encoder for IntArrayBinaryEncoder + where + B: ArrayAccessor, + { + fn encode(&mut self, idx: usize, out: &mut Vec) { + out.push(b'['); + let child = self.array.value(idx); + for (idx, byte) in child.iter().enumerate() { + write!(out, "{byte}").unwrap(); + if idx < child.len() - 1 { + out.push(b','); + } + } + out.push(b']'); + } + } + + #[derive(Debug)] + struct IntArayBinaryEncoderFactory; + + impl EncoderFactory for IntArayBinaryEncoderFactory { + fn make_default_encoder<'a>( + &self, + _field: &'a FieldRef, + array: &'a dyn Array, + _options: &'a EncoderOptions, + ) -> Result>, ArrowError> { + match array.data_type() { + DataType::Binary => { + let array = array.as_binary::(); + let encoder = IntArrayBinaryEncoder { array }; + let array_encoder = Box::new(encoder) as Box; + let nulls = array.nulls().cloned(); + Ok(Some(NullableEncoder::new(array_encoder, nulls))) + } + _ => Ok(None), + } + } + } + + let binary_array = BinaryArray::from_opt_vec(vec![Some(b"a"), None, Some(b"b")]); + let float_array = Float64Array::from(vec![Some(1.0), Some(2.3), None]); + let fields = vec![ + Field::new("bytes", DataType::Binary, true), + Field::new("float", DataType::Float64, true), + ]; + let batch = RecordBatch::try_new( + Arc::new(Schema::new(fields)), + vec![ + Arc::new(binary_array) as Arc, + Arc::new(float_array) as Arc, + ], + ) + .unwrap(); + + let json_value: Value = { + let mut buf = Vec::new(); + let mut writer = WriterBuilder::new() + .with_encoder_factory(Arc::new(IntArayBinaryEncoderFactory)) + .build::<_, JsonArray>(&mut buf); + writer.write_batches(&[&batch]).unwrap(); + writer.finish().unwrap(); + serde_json::from_slice(&buf).unwrap() + }; + + let expected = json!([ + {"bytes": [97], "float": 1.0}, + {"float": 2.3}, + {"bytes": [98]}, + ]); + + assert_eq!(json_value, expected); + } + + #[test] + fn test_encoder_factory_customize_dictionary() { + // Test that we can customize the encoding of T even when it shows up as Dictionary<_, T>. + + // No particular reason to choose this example. + // Just trying to add some variety to the test cases and demonstrate use cases of the encoder factory. + struct PaddedInt32Encoder { + array: Int32Array, + } + + impl Encoder for PaddedInt32Encoder { + fn encode(&mut self, idx: usize, out: &mut Vec) { + let value = self.array.value(idx); + write!(out, "\"{value:0>8}\"").unwrap(); + } + } + + #[derive(Debug)] + struct CustomEncoderFactory; + + impl EncoderFactory for CustomEncoderFactory { + fn make_default_encoder<'a>( + &self, + field: &'a FieldRef, + array: &'a dyn Array, + _options: &'a EncoderOptions, + ) -> Result>, ArrowError> { + // The point here is: + // 1. You can use information from Field to determine how to do the encoding. + // 2. For dictionary arrays the Field is always the outer field but the array may be the keys or values array + // and thus the data type of `field` may not match the data type of `array`. + let padded = field + .metadata() + .get("padded") + .map(|v| v == "true") + .unwrap_or_default(); + match (array.data_type(), padded) { + (DataType::Int32, true) => { + let array = array.as_primitive::(); + let nulls = array.nulls().cloned(); + let encoder = PaddedInt32Encoder { + array: array.clone(), + }; + let array_encoder = Box::new(encoder) as Box; + Ok(Some(NullableEncoder::new(array_encoder, nulls))) + } + _ => Ok(None), + } + } + } + + let to_json = |batch| { + let mut buf = Vec::new(); + let mut writer = WriterBuilder::new() + .with_encoder_factory(Arc::new(CustomEncoderFactory)) + .build::<_, JsonArray>(&mut buf); + writer.write_batches(&[batch]).unwrap(); + writer.finish().unwrap(); + serde_json::from_slice::(&buf).unwrap() + }; + + // Control case: no dictionary wrapping works as expected. + let array = Int32Array::from(vec![Some(1), None, Some(2)]); + let field = Arc::new(Field::new("int", DataType::Int32, true).with_metadata( + HashMap::from_iter(vec![("padded".to_string(), "true".to_string())]), + )); + let batch = RecordBatch::try_new( + Arc::new(Schema::new(vec![field.clone()])), + vec![Arc::new(array)], + ) + .unwrap(); + + let json_value = to_json(&batch); + + let expected = json!([ + {"int": "00000001"}, + {}, + {"int": "00000002"}, + ]); + + assert_eq!(json_value, expected); + + // Now make a dictionary batch + let mut array_builder = PrimitiveDictionaryBuilder::::new(); + array_builder.append_value(1); + array_builder.append_null(); + array_builder.append_value(1); + let array = array_builder.finish(); + let field = Field::new( + "int", + DataType::Dictionary(Box::new(DataType::UInt16), Box::new(DataType::Int32)), + true, + ) + .with_metadata(HashMap::from_iter(vec![( + "padded".to_string(), + "true".to_string(), + )])); + let batch = RecordBatch::try_new(Arc::new(Schema::new(vec![field])), vec![Arc::new(array)]) + .unwrap(); + + let json_value = to_json(&batch); + + let expected = json!([ + {"int": "00000001"}, + {}, + {"int": "00000001"}, + ]); + + assert_eq!(json_value, expected); + } }