Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 69 additions & 13 deletions rust/lance-encoding/src/encodings/logical/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@

use std::{ops::Range, sync::Arc};

use arrow_array::{Array, ArrayRef, MapArray};
use arrow_array::{Array, ArrayRef, ListArray, MapArray};
use arrow_schema::DataType;
use futures::future::BoxFuture;
use lance_arrow::deepcopy::deep_copy_nulls;
use lance_arrow::list::ListArrayExt;
use lance_core::{Error, Result};
use snafu::location;

Expand Down Expand Up @@ -53,22 +54,23 @@ impl FieldEncoder for MapStructuralEncoder {
.downcast_ref::<MapArray>()
.expect("MapEncoder used for non-map data");

// Map internally has offsets and entries (struct array)
let entries = map_array.entries();
let offsets = map_array.offsets();

// Add offsets to RepDefBuilder to handle nullability and list structure
if self.keep_original_array {
repdef.add_offsets(offsets.clone(), array.nulls().cloned())
let has_garbage_values = if self.keep_original_array {
repdef.add_offsets(map_array.offsets().clone(), array.nulls().cloned())
} else {
repdef.add_offsets(map_array.offsets().clone(), deep_copy_nulls(array.nulls()))
};

// MapArray is physically a ListArray, so convert and use ListArrayExt
let list_array: ListArray = map_array.clone().into();
let entries = if has_garbage_values {
list_array.filter_garbage_nulls().trimmed_values()
} else {
repdef.add_offsets(offsets.clone(), deep_copy_nulls(array.nulls()))
list_array.trimmed_values()
};

// Pass the entries (struct array) to the child encoder
// Convert to Arc<dyn Array>
let entries_arc: ArrayRef = Arc::new(entries.clone());
self.child
.maybe_encode(entries_arc, external_buffers, repdef, row_number, num_rows)
.maybe_encode(entries, external_buffers, repdef, row_number, num_rows)
}

fn flush(&mut self, external_buffers: &mut OutOfLineBuffers) -> Result<Vec<EncodeTask>> {
Expand Down Expand Up @@ -240,7 +242,7 @@ mod tests {
builder::{Int32Builder, MapBuilder, StringBuilder},
Array, Int32Array, MapArray, StringArray, StructArray,
};
use arrow_buffer::{OffsetBuffer, ScalarBuffer};
use arrow_buffer::{NullBuffer, OffsetBuffer, ScalarBuffer};
use arrow_schema::{DataType, Field, Fields};

use crate::encoder::{default_encoding_strategy, ColumnIndexSequence, EncodingOptions};
Expand Down Expand Up @@ -410,6 +412,60 @@ mod tests {
.await;
}

#[test_log::test(tokio::test)]
async fn test_map_in_nullable_struct() {
// Test Struct<Map> where null struct rows have garbage map entries.
// The encoder must filter these garbage entries before encoding.
let entries_fields = Fields::from(vec![
Field::new("keys", DataType::Utf8, false),
Field::new("values", DataType::Int32, true),
]);
let entries_field = Arc::new(Field::new(
"entries",
DataType::Struct(entries_fields.clone()),
false,
));
let map_entries = StructArray::new(
entries_fields,
vec![
Arc::new(StringArray::from(vec!["a", "garbage", "b"])),
Arc::new(Int32Array::from(vec![1, 999, 2])),
],
None,
);
// map0: {"a": 1}, map1 (garbage): {"garbage": 999}, map2: {"b": 2}
let map_array: Arc<dyn Array> = Arc::new(MapArray::new(
entries_field,
OffsetBuffer::new(ScalarBuffer::from(vec![0, 1, 2, 3])),
map_entries,
None, // No nulls at map level - nulls come from struct
false,
));

let struct_array = StructArray::new(
Fields::from(vec![
Field::new("id", DataType::Int32, true),
Field::new("props", map_array.data_type().clone(), true),
]),
vec![
Arc::new(Int32Array::from(vec![Some(1), Some(2), Some(3)])),
map_array,
],
Some(NullBuffer::from(vec![true, false, true])), // Middle row is null
);

let test_cases = TestCases::default()
.with_range(0..3)
.with_min_file_version(LanceFileVersion::V2_2);

check_round_trip_encoding_of_data(
vec![Arc::new(struct_array)],
&test_cases,
HashMap::new(),
)
.await;
}

#[test_log::test(tokio::test)]
async fn test_list_of_maps() {
// Test List<Map<String, Int32>>
Expand Down
Loading