diff --git a/python/python/tests/test_map_type.py b/python/python/tests/test_map_type.py new file mode 100644 index 00000000000..c7cf1f5614e --- /dev/null +++ b/python/python/tests/test_map_type.py @@ -0,0 +1,852 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +from pathlib import Path + +import lance +import pyarrow as pa +import pytest + + +def test_simple_map_write_read(tmp_path: Path): + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("properties", pa.map_(pa.string(), pa.int32())), + ] + ) + + data = pa.table( + { + "id": [1, 2, 3], + "properties": [ + [("key1", 10), ("key2", 20)], + [("key3", 30)], + [("key4", 40), ("key5", 50), ("key6", 60)], + ], + }, + schema=schema, + ) + + # Write to Lance (requires v2.2+) + dataset = lance.write_dataset(data, tmp_path, data_storage_version="2.2") + + # Read and verify + result = dataset.to_table() + + assert result.schema == schema + assert result.equals(data) + + +def test_map_with_nulls(tmp_path: Path): + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("properties", pa.map_(pa.string(), pa.int32())), + ] + ) + + data = pa.table( + { + "id": [1, 2, 3, 4], + "properties": [ + [("key1", 10)], + None, # null map + [], # empty map + [("key2", 20), ("key3", 30)], + ], + }, + schema=schema, + ) + + dataset = lance.write_dataset(data, tmp_path, data_storage_version="2.2") + result = dataset.to_table() + + assert result.schema == schema + assert result.equals(data) + + +def test_map_with_null_values(tmp_path: Path): + schema = pa.schema( + [pa.field("id", pa.int32()), pa.field("data", pa.map_(pa.string(), pa.int32()))] + ) + + # Create map with null values using simple notation + data = pa.table( + { + "id": [1, 2], + "data": [ + [("a", 1), ("b", None)], # Second value is null + [("c", 3), ("d", None)], # Fourth value is null + ], + }, + schema=schema, + ) + + dataset = lance.write_dataset(data, tmp_path, data_storage_version="2.2") + result = dataset.to_table() + + assert result.schema == schema + assert result.equals(data) + + +def test_empty_maps(tmp_path: Path): + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("map_field", pa.map_(pa.string(), pa.string())), + ] + ) + + data = pa.table( + { + "id": [1, 2, 3], + "map_field": [ + [("a", "apple")], + [], # empty map + [("b", "banana")], + ], + }, + schema=schema, + ) + + dataset = lance.write_dataset(data, tmp_path, data_storage_version="2.2") + result = dataset.to_table() + + assert result.schema == schema + assert result.equals(data) + + +def test_nested_map_in_struct(tmp_path: Path): + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field( + "record", + pa.struct( + [ + pa.field("name", pa.string()), + pa.field("attributes", pa.map_(pa.string(), pa.string())), + ] + ), + ), + ] + ) + + data = pa.table( + { + "id": [1, 2, 3], + "record": [ + {"name": "Alice", "attributes": [("city", "NYC"), ("age", "30")]}, + {"name": "Bob", "attributes": [("city", "LA")]}, + {"name": "Charlie", "attributes": None}, + ], + }, + schema=schema, + ) + + dataset = lance.write_dataset(data, tmp_path, data_storage_version="2.2") + result = dataset.to_table() + + assert result.schema == schema + assert result.equals(data) + + +def test_list_of_maps(tmp_path: Path): + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("configs", pa.list_(pa.map_(pa.string(), pa.int32()))), + ] + ) + + data = pa.table( + { + "id": [1, 2], + "configs": [ + [ + [("a", 1), ("b", 2)], # first map + [("c", 3)], # second map + ], + [ + [("d", 4), ("e", 5)] # first map + ], + ], + }, + schema=schema, + ) + + dataset = lance.write_dataset(data, tmp_path, data_storage_version="2.2") + result = dataset.to_table() + + assert result.schema == schema + assert result.equals(data) + + +def test_map_different_key_types(tmp_path: Path): + # Test Map + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("int_map", pa.map_(pa.int32(), pa.string())), + ] + ) + + data = pa.table( + { + "id": [1, 2], + "int_map": [[(1, "one"), (2, "two")], [(3, "three"), (4, "four")]], + }, + schema=schema, + ) + + dataset = lance.write_dataset(data, tmp_path, data_storage_version="2.2") + result = dataset.to_table() + + assert result.schema == schema + assert result.equals(data) + + +def test_query_map_column(tmp_path: Path): + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("properties", pa.map_(pa.string(), pa.int32())), + ] + ) + + data = pa.table( + { + "id": [1, 2, 3, 4], + "properties": [ + [("key1", 10), ("key2", 20)], + [("key3", 30)], + [("key4", 40)], + [("key5", 50)], + ], + }, + schema=schema, + ) + + dataset = lance.write_dataset(data, tmp_path, data_storage_version="2.2") + + # Column selection (full read) + result = dataset.to_table(columns=["id"]) + assert result.schema.names == ["id"] + assert result.num_rows == 4 + + # Full read with Map column + result = dataset.to_table() + assert "properties" in result.schema.names + assert result.num_rows == 4 + + result = dataset.to_table(filter="id > 2") + assert result.num_rows == 2 + + +def test_map_value_types(tmp_path: Path): + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("string_map", pa.map_(pa.string(), pa.string())), + pa.field("float_map", pa.map_(pa.string(), pa.float64())), + pa.field("bool_map", pa.map_(pa.string(), pa.bool_())), + ] + ) + + data = pa.table( + { + "id": [1, 2], + "string_map": [[("a", "apple"), ("b", "banana")], [("c", "cherry")]], + "float_map": [[("x", 1.5), ("y", 2.5)], [("z", 3.5)]], + "bool_map": [[("flag1", True), ("flag2", False)], [("flag3", True)]], + }, + schema=schema, + ) + + dataset = lance.write_dataset(data, tmp_path, data_storage_version="2.2") + result = dataset.to_table() + + assert result.schema == schema + assert result.equals(data) + + +def test_map_append_data(tmp_path: Path): + schema = pa.schema( + [pa.field("id", pa.int32()), pa.field("data", pa.map_(pa.string(), pa.int32()))] + ) + + # Initial data + data1 = pa.table({"id": [1, 2], "data": [[("a", 1)], [("b", 2)]]}, schema=schema) + + lance.write_dataset(data1, tmp_path, data_storage_version="2.2") + + # Append more data + data2 = pa.table({"id": [3, 4], "data": [[("c", 3)], [("d", 4)]]}, schema=schema) + + # Reopen dataset before appending + lance.write_dataset(data2, tmp_path, mode="append", data_storage_version="2.2") + + # Reopen and read + dataset_reopened = lance.dataset(tmp_path) + result = dataset_reopened.to_table() + assert result.num_rows == 4 + assert result["id"].to_pylist() == [1, 2, 3, 4] + + +def test_map_large_entries(tmp_path: Path): + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("big_map", pa.map_(pa.string(), pa.int32())), + ] + ) + + # Create a map with 100 entries + large_map = [(f"key{i}", i * 10) for i in range(100)] + + data = pa.table( + { + "id": [1, 2], + "big_map": [large_map, large_map[:50]], # Second map has 50 entries + }, + schema=schema, + ) + + dataset = lance.write_dataset(data, tmp_path, data_storage_version="2.2") + result = dataset.to_table() + + assert result.schema == schema + assert result.equals(data) + + +def test_map_version_compatibility(tmp_path: Path): + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("map_field", pa.map_(pa.string(), pa.int32())), + ] + ) + + data = pa.table( + {"id": [1, 2], "map_field": [[("a", 1)], [("b", 2)]]}, schema=schema + ) + + # Writing with v2.2 should succeed + dataset = lance.write_dataset(data, tmp_path / "v22", data_storage_version="2.2") + result = dataset.to_table() + assert result.equals(data) + + # should raise an error for v2.1 + with pytest.raises(Exception) as exc_info: + lance.write_dataset(data, tmp_path / "v21", data_storage_version="2.1") + # Verify error message + error_msg = str(exc_info.value) + assert ( + "Map data type" in error_msg + or "not yet implemented" in error_msg.lower() + or "not supported" in error_msg.lower() + ) + + +def test_map_roundtrip_preservation(tmp_path: Path): + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("map1", pa.map_(pa.string(), pa.int32())), + pa.field("map2", pa.map_(pa.int32(), pa.string())), + ] + ) + + data = pa.table( + {"id": [1], "map1": [[("z", 1), ("a", 2)]], "map2": [[(1, "a"), (2, "b")]]}, + schema=schema, + ) + + dataset = lance.write_dataset(data, tmp_path, data_storage_version="2.2") + result = dataset.to_table() + + # Verify Map types + map1_type = result.schema.field("map1").type + map2_type = result.schema.field("map2").type + + assert isinstance(map1_type, pa.MapType) + assert isinstance(map2_type, pa.MapType) + + # Verify data content + assert result["id"].to_pylist() == [1] + assert len(result["map1"][0]) == 2 + assert len(result["map2"][0]) == 2 + + +def test_map_keys_cannot_be_null(tmp_path: Path): + # Arrow Map spec requires keys to be non-nullable + # The key field in the entries struct must have nullable=False + + # Test 1: Valid map with non-nullable keys (default behavior) + schema_valid = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("valid_map", pa.map_(pa.string(), pa.int32())), + ] + ) + + data_valid = pa.table( + {"id": [1, 2], "valid_map": [[("a", 1), ("b", 2)], [("c", 3)]]}, + schema=schema_valid, + ) + + # This should succeed + dataset = lance.write_dataset( + data_valid, tmp_path / "valid", data_storage_version="2.2" + ) + result = dataset.to_table() + assert result.equals(data_valid) + + # Verify the key field is non-nullable in the schema + map_type = result.schema.field("valid_map").type + assert isinstance(map_type, pa.MapType) + + # Access the key and value types + assert map_type.key_type == pa.string() + assert map_type.item_type == pa.int32() + + # Test 2: Verify we can write maps with null values (but not null keys) + data_null_values = pa.table( + { + "id": [1, 2], + "map_with_null_values": [ + [("a", 1), ("b", None)], # null value is OK + [("c", None)], # null value is OK + ], + }, + schema=pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("map_with_null_values", pa.map_(pa.string(), pa.int32())), + ] + ), + ) + + dataset2 = lance.write_dataset( + data_null_values, tmp_path / "null_values", data_storage_version="2.2" + ) + result2 = dataset2.to_table() + + # Verify null values in map are preserved + assert result2["id"].to_pylist() == [1, 2] + map_data = result2["map_with_null_values"] + + # First map has 2 entries + first_map = map_data[0] + assert len(first_map) == 2 + + # Values can be null + values_list = [item[1] for item in first_map.as_py()] + assert None in values_list # At least one null value + + # Test 3: Verify we cannot write maps with null keys + with pytest.raises(Exception): + pa.table( + { + "id": [1, 2], + "null_key_map": [ + [(None, 1), ("b", 2)], # null key is not allowed + [("c", 3)], + ], + }, + schema=pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("null_key_map", pa.map_(pa.string(), pa.int32())), + ] + ), + ) + + +def test_map_projection_queries(tmp_path: Path): + # Create a dataset with multiple columns including Map types + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("name", pa.string()), + pa.field("properties", pa.map_(pa.string(), pa.int32())), + pa.field("tags", pa.map_(pa.string(), pa.string())), + pa.field("score", pa.float64()), + ] + ) + + data = pa.table( + { + "id": [1, 2, 3, 4, 5], + "name": ["Alice", "Bob", "Charlie", "David", "Eve"], + "properties": [ + [("age", 25), ("height", 170)], + [("age", 30), ("weight", 75)], + [("age", 35)], + None, # null map + [("age", 28), ("height", 165), ("weight", 60)], + ], + "tags": [ + [("role", "admin"), ("status", "active")], + [("role", "user")], + [("status", "inactive")], + [("role", "guest")], + [("role", "user"), ("status", "active")], + ], + "score": [95.5, 87.3, 91.2, 78.9, 88.7], + }, + schema=schema, + ) + + dataset = lance.write_dataset(data, tmp_path, data_storage_version="2.2") + + # Test 1: Project only map column + result1 = dataset.to_table(columns=["properties"]) + assert result1.num_rows == 5, "Row count mismatch for single map column projection" + assert result1.schema.names == ["properties"], "Schema names mismatch" + assert result1.schema.field("properties").type == pa.map_( + pa.string(), pa.int32() + ), "Map type mismatch" + # Verify data consistency + assert result1["properties"][0].as_py() == [("age", 25), ("height", 170)] + assert result1["properties"][3].as_py() is None # null map preserved + + # Test 2: Project multiple columns including map + result2 = dataset.to_table(columns=["id", "properties", "score"]) + assert result2.num_rows == 5, "Row count mismatch for multi-column projection" + assert result2.schema.names == ["id", "properties", "score"], ( + "Schema names mismatch" + ) + assert result2["id"].to_pylist() == [1, 2, 3, 4, 5], "ID data mismatch" + assert result2["score"].to_pylist() == [95.5, 87.3, 91.2, 78.9, 88.7], ( + "Score data mismatch" + ) + + # Test 3: Project two map columns + result3 = dataset.to_table(columns=["properties", "tags"]) + assert result3.num_rows == 5, "Row count mismatch for two map columns" + assert result3.schema.names == ["properties", "tags"], "Schema names mismatch" + assert isinstance(result3.schema.field("properties").type, pa.MapType) + assert isinstance(result3.schema.field("tags").type, pa.MapType) + # Verify both map columns have correct data + assert result3["tags"][0].as_py() == [("role", "admin"), ("status", "active")] + + # Test 4: Projection with filter + result4 = dataset.to_table(columns=["id", "name", "properties"], filter="id > 2") + assert result4.num_rows == 3, ( + "Row count mismatch with filter (expected 3 rows for id > 2)" + ) + assert result4.schema.names == ["id", "name", "properties"], ( + "Schema names mismatch with filter" + ) + assert result4["id"].to_pylist() == [3, 4, 5], "Filtered ID data mismatch" + assert result4["name"].to_pylist() == ["Charlie", "David", "Eve"], ( + "Filtered name data mismatch" + ) + # Verify map data is correct for filtered rows + assert result4["properties"][0].as_py() == [("age", 35)] # Charlie's properties + assert result4["properties"][1].as_py() is None # David's properties (null) + + # Test 5: Projection with more complex filter + result5 = dataset.to_table(columns=["id", "properties"], filter="score >= 90") + assert result5.num_rows == 2, ( + "Row count mismatch with score filter (expected 2 rows)" + ) + assert result5.schema.names == ["id", "properties"], ( + "Should only contain id and properties columns" + ) + assert result5["id"].to_pylist() == [1, 3], ( + "Filtered ID data mismatch for score >= 90" + ) + + # Test 6: Project all columns (no projection) + result6 = dataset.to_table() + assert result6.num_rows == 5, "Row count mismatch for full table read" + assert result6.schema == schema, "Full schema mismatch" + assert result6.equals(data), "Full data mismatch" + + # Test 7: Project only non-map columns + result7 = dataset.to_table(columns=["id", "name", "score"]) + assert result7.num_rows == 5, "Row count mismatch for non-map projection" + assert result7.schema.names == ["id", "name", "score"], ( + "Should only contain id, name and score columns" + ) + assert "properties" not in result7.schema.names, ( + "Map column should not be in result" + ) + assert "tags" not in result7.schema.names, "Map column should not be in result" + assert result7["name"].to_pylist() == ["Alice", "Bob", "Charlie", "David", "Eve"] + + +def test_map_projection_nested_struct(tmp_path: Path): + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field( + "user", + pa.struct( + [ + pa.field("name", pa.string()), + pa.field("metadata", pa.map_(pa.string(), pa.string())), + pa.field("age", pa.int32()), + ] + ), + ), + pa.field("extra", pa.string()), + ] + ) + + data = pa.table( + { + "id": [1, 2, 3], + "user": [ + { + "name": "Alice", + "metadata": [("city", "NYC"), ("country", "USA")], + "age": 30, + }, + {"name": "Bob", "metadata": [("city", "LA")], "age": 25}, + {"name": "Charlie", "metadata": None, "age": 35}, + ], + "extra": ["info1", "info2", "info3"], + }, + schema=schema, + ) + + dataset = lance.write_dataset(data, tmp_path, data_storage_version="2.2") + + # Test 1: Project the entire struct containing map + result1 = dataset.to_table(columns=["id", "user"]) + assert result1.num_rows == 3, "Row count mismatch" + assert result1.schema.names == ["id", "user"], "Schema names mismatch" + # Verify struct schema + user_type = result1.schema.field("user").type + assert isinstance(user_type, pa.StructType) + # Verify nested map type + metadata_field = user_type.field("metadata") + assert isinstance(metadata_field.type, pa.MapType) + # Verify data + assert result1["user"][0].as_py()["name"] == "Alice" + assert result1["user"][0].as_py()["metadata"] == [ + ("city", "NYC"), + ("country", "USA"), + ] + + # Test 2: Project struct with filter + result2 = dataset.to_table(columns=["user"], filter="id > 1") + assert result2.num_rows == 2, "Row count mismatch with filter" + assert result2.schema.names == ["user"], "Should only contain user column" + assert result2["user"][0].as_py()["name"] == "Bob" + assert result2["user"][1].as_py()["metadata"] is None # Charlie has null metadata + + # Test 3: Project only id and extra (not the struct with map) + result3 = dataset.to_table(columns=["id", "extra"]) + assert result3.num_rows == 3, "Row count mismatch" + assert result3.schema.names == ["id", "extra"], ( + "Should only contain id and extra columns" + ) + assert "user" not in result3.schema.names, "Struct column should not be in result" + assert result3["extra"].to_pylist() == ["info1", "info2", "info3"] + + +def test_map_projection_list_of_maps(tmp_path: Path): + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("configs", pa.list_(pa.map_(pa.string(), pa.int32()))), + pa.field("name", pa.string()), + ] + ) + + data = pa.table( + { + "id": [1, 2, 3, 4], + "configs": [ + [[("port", 8080), ("timeout", 30)], [("port", 8081), ("retries", 3)]], + [[("port", 9090)]], + None, # null list + [[("port", 7070), ("timeout", 60)], [("retries", 5)], [("port", 7071)]], + ], + "name": ["service1", "service2", "service3", "service4"], + }, + schema=schema, + ) + + dataset = lance.write_dataset(data, tmp_path, data_storage_version="2.2") + + # Test 1: Project list of maps + result1 = dataset.to_table(columns=["configs"]) + assert result1.num_rows == 4, "Row count mismatch" + assert result1.schema.names == ["configs"], "Should only contain configs column" + list_type = result1.schema.field("configs").type + assert isinstance(list_type, pa.ListType) + assert isinstance(list_type.value_type, pa.MapType) + # Verify data + assert len(result1["configs"][0]) == 2 # Two maps in first list + assert result1["configs"][2].as_py() is None # Null list + + # Test 2: Project with id and configs + result2 = dataset.to_table(columns=["id", "configs"]) + assert result2.num_rows == 4, "Row count mismatch" + assert result2.schema.names == ["id", "configs"], ( + "Should only contain id and configs columns" + ) + assert result2["id"].to_pylist() == [1, 2, 3, 4] + assert len(result2["configs"][3]) == 3 # Three maps in last list + + # Test 3: Projection with filter + result3 = dataset.to_table(columns=["id", "configs", "name"], filter="id <= 2") + assert result3.num_rows == 2, "Row count mismatch with filter" + assert result3.schema.names == ["id", "configs", "name"], ( + "Should only contain id, configs and name columns" + ) + assert result3["name"].to_pylist() == ["service1", "service2"] + # Verify the list of maps data for filtered rows + first_configs = result3["configs"][0].as_py() + assert len(first_configs) == 2 + assert first_configs[0] == [("port", 8080), ("timeout", 30)] + + +def test_map_projection_multiple_value_types(tmp_path: Path): + schema = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("int_map", pa.map_(pa.string(), pa.int32())), + pa.field("float_map", pa.map_(pa.string(), pa.float64())), + pa.field("string_map", pa.map_(pa.string(), pa.string())), + pa.field("bool_map", pa.map_(pa.string(), pa.bool_())), + ] + ) + + data = pa.table( + { + "id": [1, 2, 3], + "int_map": [[("a", 1), ("b", 2)], [("c", 3)], None], + "float_map": [[("x", 1.5), ("y", 2.5)], [("z", 3.5)], [("w", 4.5)]], + "string_map": [ + [("k1", "v1"), ("k2", "v2")], + [("k3", "v3")], + [("k4", "v4"), ("k5", "v5")], + ], + "bool_map": [ + [("flag1", True)], + [("flag2", False)], + [("flag3", True), ("flag4", False)], + ], + }, + schema=schema, + ) + + dataset = lance.write_dataset(data, tmp_path, data_storage_version="2.2") + + # Test 1: Project subset of map columns + result1 = dataset.to_table(columns=["id", "int_map", "string_map"]) + assert result1.num_rows == 3, "Row count mismatch" + assert result1.schema.names == ["id", "int_map", "string_map"] + assert result1.schema.field("int_map").type == pa.map_(pa.string(), pa.int32()) + assert result1.schema.field("string_map").type == pa.map_(pa.string(), pa.string()) + + # Test 2: Project all map columns (no id) + result2 = dataset.to_table( + columns=["int_map", "float_map", "string_map", "bool_map"] + ) + assert result2.num_rows == 3, "Row count mismatch" + assert len(result2.schema.names) == 4 + # Verify all are map types + for col in result2.schema.names: + assert isinstance(result2.schema.field(col).type, pa.MapType) + + # Test 3: Project single map column with filter + result3 = dataset.to_table(columns=["float_map"], filter="id != 2") + assert result3.num_rows == 2, "Row count mismatch with filter" + assert result3.schema.names == ["float_map"], "Should only contain float_map column" + assert result3["float_map"][0].as_py() == [("x", 1.5), ("y", 2.5)] + assert result3["float_map"][1].as_py() == [("w", 4.5)] + + # Test 4: Verify data consistency for all projections + result4 = dataset.to_table(columns=["id", "bool_map"]) + assert result4.num_rows == 3, "Row count mismatch" + assert result4.schema.names == ["id", "bool_map"], ( + "Should only contain id and bool_map columns" + ) + assert result4["bool_map"][0].as_py() == [("flag1", True)] + assert result4["bool_map"][1].as_py() == [("flag2", False)] + assert result4["bool_map"][2].as_py() == [("flag3", True), ("flag4", False)] + + +def test_map_keys_sorted_unsupported(tmp_path: Path): + """Test that keys_sorted=True is not supported""" + # Test that keys_sorted=True is rejected + schema_sorted = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field("sorted_map", pa.map_(pa.string(), pa.int32(), keys_sorted=True)), + ] + ) + + data_sorted = pa.table( + {"id": [1, 2], "sorted_map": [[("a", 1), ("b", 2)], [("c", 3)]]}, + schema=schema_sorted, + ) + + # Writing should fail with keys_sorted=True + with pytest.raises(Exception) as exc_info: + lance.write_dataset( + data_sorted, tmp_path / "sorted", data_storage_version="2.2" + ) + error_msg = str(exc_info.value) + assert ( + "keys_sorted=true" in error_msg.lower() + or "unsupported map field" in error_msg.lower() + ), f"Expected error about keys_sorted=true, got: {error_msg}" + + # Test that keys_sorted=False (default) is supported + schema_unsorted = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field( + "unsorted_map", pa.map_(pa.string(), pa.int32(), keys_sorted=False) + ), + ] + ) + + data_unsorted = pa.table( + {"id": [1, 2], "unsorted_map": [[("z", 1), ("a", 2)], [("c", 3)]]}, + schema=schema_unsorted, + ) + + dataset_unsorted = lance.write_dataset( + data_unsorted, tmp_path / "unsorted", data_storage_version="2.2" + ) + result_unsorted = dataset_unsorted.to_table() + + # Verify keys_sorted=False is preserved + map_type_unsorted = result_unsorted.schema.field("unsorted_map").type + assert isinstance(map_type_unsorted, pa.MapType) + assert map_type_unsorted.keys_sorted is False + + # Test that default (keys_sorted=False) works + schema_default = pa.schema( + [ + pa.field("id", pa.int32()), + pa.field( + "default_map", pa.map_(pa.string(), pa.int32()) + ), # default is False + ] + ) + + data_default = pa.table( + {"id": [1, 2], "default_map": [[("z", 1), ("a", 2)], [("c", 3)]]}, + schema=schema_default, + ) + + dataset_default = lance.write_dataset( + data_default, tmp_path / "default", data_storage_version="2.2" + ) + result_default = dataset_default.to_table() + + # Verify default keys_sorted=False is preserved + map_type_default = result_default.schema.field("default_map").type + assert isinstance(map_type_default, pa.MapType) + assert map_type_default.keys_sorted is False diff --git a/rust/lance-core/src/datatypes.rs b/rust/lance-core/src/datatypes.rs index f6400c5a5ae..a0a1fd1817a 100644 --- a/rust/lance-core/src/datatypes.rs +++ b/rust/lance-core/src/datatypes.rs @@ -98,6 +98,10 @@ impl LogicalType { fn is_blob(&self) -> bool { self.0 == BLOB_LOGICAL_TYPE } + + fn is_map(&self) -> bool { + self.0 == "map" + } } impl From<&str> for LogicalType { @@ -196,6 +200,21 @@ impl TryFrom<&DataType> for LogicalType { } } DataType::FixedSizeBinary(len) => format!("fixed_size_binary:{}", *len), + DataType::Map(_, keys_sorted) => { + // TODO: We only support keys_sorted=false for now, + // because converting a rust arrow map field to the python arrow field will + // lose the keys_sorted property. + if *keys_sorted { + return Err(Error::Schema { + message: format!( + "Unsupported map data type with keys_sorted=true: {:?}", + dt + ), + location: location!(), + }); + } + "map".to_string() + } _ => { return Err(Error::Schema { message: format!("Unsupported data type: {:?}", dt), diff --git a/rust/lance-core/src/datatypes/field.rs b/rust/lance-core/src/datatypes/field.rs index 4b42c3581ac..7f0bc0327db 100644 --- a/rust/lance-core/src/datatypes/field.rs +++ b/rust/lance-core/src/datatypes/field.rs @@ -168,6 +168,9 @@ impl Field { lt if lt.is_struct() => { DataType::Struct(self.children.iter().map(ArrowField::from).collect()) } + lt if lt.is_map() => { + DataType::Map(Arc::new(ArrowField::from(&self.children[0])), false) + } lt => DataType::try_from(lt).unwrap(), } } @@ -250,11 +253,17 @@ impl Field { } pub fn apply_projection(&self, projection: &Projection) -> Option { - let children = self - .children - .iter() - .filter_map(|c| c.apply_projection(projection)) - .collect::>(); + // For Map types, we must preserve ALL children (entries struct with key/value) + // Map internal structure should not be subject to projection filtering + let children = if self.logical_type.is_map() { + // Map field: keep all children intact (entries struct and its key/value fields) + self.children.clone() + } else { + self.children + .iter() + .filter_map(|c| c.apply_projection(projection)) + .collect::>() + }; // The following case is invalid: // - This is a nested field (has children) @@ -691,7 +700,8 @@ impl Field { Ok(cloned) } (DataType::List(_), DataType::List(_)) - | (DataType::LargeList(_), DataType::LargeList(_)) => { + | (DataType::LargeList(_), DataType::LargeList(_)) + | (DataType::Map(_, _), DataType::Map(_, _)) => { let projected = self.children[0].project_by_field(&other.children[0], on_type_mismatch)?; let mut cloned = self.clone(); @@ -769,7 +779,9 @@ impl Field { if matches!( (&self_type, &other_type), - (DataType::Struct(_), DataType::Struct(_)) | (DataType::List(_), DataType::List(_)) + (DataType::Struct(_), DataType::Struct(_)) + | (DataType::List(_), DataType::List(_)) + | (DataType::Map(_, _), DataType::Map(_, _)) ) { // Blob v2 uses a struct logical type for descriptors, which differs from the logical // input struct (data/uri). When intersecting schemas for projection we want to keep @@ -1023,6 +1035,7 @@ impl TryFrom<&ArrowField> for Field { type Error = Error; fn try_from(field: &ArrowField) -> Result { + let mut metadata = field.metadata().clone(); let children = match field.data_type() { DataType::Struct(children) => children .iter() @@ -1030,9 +1043,44 @@ impl TryFrom<&ArrowField> for Field { .collect::>()?, DataType::List(item) => vec![Self::try_from(item.as_ref())?], DataType::LargeList(item) => vec![Self::try_from(item.as_ref())?], + DataType::Map(entries, keys_sorted) => { + // TODO: We only support keys_sorted=false for now, + // because converting a rust arrow map field to the python arrow field will + // lose the keys_sorted property. + if *keys_sorted { + return Err(Error::Schema { + message: "Unsupported map field with keys_sorted=true".to_string(), + location: location!(), + }); + } + // Validate Map entries follow Arrow specification + let DataType::Struct(struct_fields) = entries.data_type() else { + return Err(Error::Schema { + message: "Map entries field must be a Struct".to_string(), + location: location!(), + }); + }; + if struct_fields.len() < 2 { + return Err(Error::Schema { + message: "Map entries struct must contain both key and value fields" + .to_string(), + location: location!(), + }); + } + let key_field = &struct_fields[0]; + if key_field.is_nullable() { + return Err(Error::Schema { + message: format!( + "Map key field '{}' must be non-nullable according to Arrow Map specification", + key_field.name() + ), + location: location!(), + }); + } + vec![Self::try_from(entries.as_ref())?] + } _ => vec![], }; - let mut metadata = field.metadata().clone(); let unenforced_primary_key = metadata .get(LANCE_UNENFORCED_PRIMARY_KEY) .map(|s| matches!(s.to_lowercase().as_str(), "true" | "1" | "yes")) @@ -1063,8 +1111,10 @@ impl TryFrom<&ArrowField> for Field { dt if dt.is_fixed_stride() => Some(Encoding::Plain), dt if dt.is_binary_like() => Some(Encoding::VarBinary), DataType::Dictionary(_, _) => Some(Encoding::Dictionary), - // Use plain encoder to store the offsets of list. - DataType::List(_) | DataType::LargeList(_) => Some(Encoding::Plain), + // Use plain encoder to store the offsets of list and map. + DataType::List(_) | DataType::LargeList(_) | DataType::Map(_, _) => { + Some(Encoding::Plain) + } _ => None, }, metadata, @@ -1206,6 +1256,23 @@ mod tests { .0, "struct" ); + + assert_eq!( + LogicalType::try_from(&DataType::Map( + Arc::new(ArrowField::new( + "entries", + DataType::Struct(Fields::from(vec![ + ArrowField::new("key", DataType::Utf8, false), + ArrowField::new("value", DataType::Int32, true), + ])), + true + )), + false + )) + .unwrap() + .0, + "map" + ); } #[test] @@ -1225,6 +1292,89 @@ mod tests { assert_eq!(ArrowField::from(&field), arrow_field); } + #[test] + fn map_key_must_be_non_nullable() { + let entries_field = Arc::new(ArrowField::new( + "entries", + DataType::Struct(Fields::from(vec![ + ArrowField::new("key", DataType::Utf8, true), // invalid: nullable key + ArrowField::new("value", DataType::Int32, true), + ])), + false, + )); + let arrow_field = ArrowField::new("props", DataType::Map(entries_field, false), true); + + let result = Field::try_from(&arrow_field); + assert!(result.is_err(), "Nullable map key should be rejected"); + } + + #[test] + fn map_keys_sorted_unsupported() { + let entries_field = Arc::new(ArrowField::new( + "entries", + DataType::Struct(Fields::from(vec![ + ArrowField::new("key", DataType::Utf8, false), + ArrowField::new("value", DataType::Int32, true), + ])), + false, + )); + + // Test that keys_sorted=true is rejected + let arrow_field_sorted = ArrowField::new( + "map_field", + DataType::Map(entries_field.clone(), true), + true, + ); + let result = Field::try_from(&arrow_field_sorted); + assert!(result.is_err(), "keys_sorted=true should be rejected"); + assert!(result.unwrap_err().to_string().contains("keys_sorted=true")); + + // Test that keys_sorted=false is supported + let arrow_field_unsorted = + ArrowField::new("map_field", DataType::Map(entries_field, false), true); + let lance_field_unsorted = Field::try_from(&arrow_field_unsorted).unwrap(); + + // Verify conversion back to ArrowField preserves keys_sorted=false + let converted_field_unsorted = ArrowField::from(&lance_field_unsorted); + match converted_field_unsorted.data_type() { + DataType::Map(_, keys_sorted) => assert!(!keys_sorted, "keys_sorted should be false"), + _ => panic!("Expected Map type"), + } + } + + #[test] + fn map_entries_must_be_struct() { + let entries_field = Arc::new(ArrowField::new("entries", DataType::Utf8, false)); + let arrow_field = ArrowField::new("map_field", DataType::Map(entries_field, false), true); + + let err = Field::try_from(&arrow_field).unwrap_err(); + assert!( + err.to_string() + .contains("Map entries field must be a Struct"), + "Expected struct requirement error, got {err}" + ); + } + + #[test] + fn map_entries_struct_needs_key_and_value() { + let entries_field = Arc::new(ArrowField::new( + "entries", + DataType::Struct(Fields::from(vec![ArrowField::new( + "key", + DataType::Utf8, + false, + )])), + false, + )); + let arrow_field = ArrowField::new("map_field", DataType::Map(entries_field, false), true); + + let err = Field::try_from(&arrow_field).unwrap_err(); + assert!( + err.to_string().contains("must contain both key and value"), + "Expected both fields requirement error, got {err}" + ); + } + #[test] fn test_project_by_field_null_type() { let f1: Field = ArrowField::new("a", DataType::Null, true) diff --git a/rust/lance-core/src/datatypes/schema.rs b/rust/lance-core/src/datatypes/schema.rs index cdcc3cef1e6..808ea8b74e0 100644 --- a/rust/lance-core/src/datatypes/schema.rs +++ b/rust/lance-core/src/datatypes/schema.rs @@ -697,6 +697,16 @@ impl TryFrom<&ArrowSchema> for Schema { location: location!(), }); } + + if ancestor.logical_type.is_map() { + return Err(Error::Schema { + message: format!( + "Primary key column must not be in a map type: {}", + ancestor + ), + location: location!(), + }); + } } } } diff --git a/rust/lance-encoding/src/decoder.rs b/rust/lance-encoding/src/decoder.rs index a1bc2e93e25..3c7e51f0cc0 100644 --- a/rust/lance-encoding/src/decoder.rs +++ b/rust/lance-encoding/src/decoder.rs @@ -239,6 +239,7 @@ use crate::compression::{DecompressionStrategy, DefaultDecompressionStrategy}; use crate::data::DataBlock; use crate::encoder::EncodedBatch; use crate::encodings::logical::list::StructuralListScheduler; +use crate::encodings::logical::map::StructuralMapScheduler; use crate::encodings::logical::primitive::StructuralPrimitiveFieldScheduler; use crate::encodings::logical::r#struct::{StructuralStructDecoder, StructuralStructScheduler}; use crate::format::pb::{self, column_encoding}; @@ -773,6 +774,25 @@ impl CoreFieldDecoderStrategy { Ok(Box::new(StructuralListScheduler::new(child_scheduler)) as Box) } + DataType::Map(_, keys_sorted) => { + // TODO: We only support keys_sorted=false for now, + // because converting a rust arrow map field to the python arrow field will + // lose the keys_sorted property. + if *keys_sorted { + return Err(Error::NotSupported { + source: format!("Map data type is not supported with keys_sorted=true now, current value is {}", *keys_sorted).into(), + location: location!(), + }); + } + let entries_child = field + .children + .first() + .expect("Map field must have an entries child"); + let child_scheduler = + self.create_structural_field_scheduler(entries_child, column_infos)?; + Ok(Box::new(StructuralMapScheduler::new(child_scheduler)) + as Box) + } _ => todo!("create_structural_field_scheduler for {}", data_type), } } @@ -1860,21 +1880,24 @@ pub fn create_decode_stream( is_structural: bool, should_validate: bool, rx: mpsc::UnboundedReceiver>, -) -> BoxStream<'static, ReadBatchTask> { +) -> Result> { if is_structural { let arrow_schema = ArrowSchema::from(schema); let structural_decoder = StructuralStructDecoder::new( arrow_schema.fields, should_validate, /*is_root=*/ true, - ); - StructuralBatchDecodeStream::new(rx, batch_size, num_rows, structural_decoder).into_stream() + )?; + Ok( + StructuralBatchDecodeStream::new(rx, batch_size, num_rows, structural_decoder) + .into_stream(), + ) } else { let arrow_schema = ArrowSchema::from(schema); let root_fields = arrow_schema.fields; let simple_struct_decoder = SimpleStructDecoder::new(root_fields, num_rows); - BatchDecodeStream::new(rx, batch_size, num_rows, simple_struct_decoder).into_stream() + Ok(BatchDecodeStream::new(rx, batch_size, num_rows, simple_struct_decoder).into_stream()) } } @@ -1888,28 +1911,28 @@ pub fn create_decode_iterator( should_validate: bool, is_structural: bool, messages: VecDeque>, -) -> Box { +) -> Result> { let arrow_schema = Arc::new(ArrowSchema::from(schema)); let root_fields = arrow_schema.fields.clone(); if is_structural { let simple_struct_decoder = - StructuralStructDecoder::new(root_fields, should_validate, /*is_root=*/ true); - Box::new(BatchDecodeIterator::new( + StructuralStructDecoder::new(root_fields, should_validate, /*is_root=*/ true)?; + Ok(Box::new(BatchDecodeIterator::new( messages, batch_size, num_rows, simple_struct_decoder, arrow_schema, - )) + ))) } else { let root_decoder = SimpleStructDecoder::new(root_fields, num_rows); - Box::new(BatchDecodeIterator::new( + Ok(Box::new(BatchDecodeIterator::new( messages, batch_size, num_rows, root_decoder, arrow_schema, - )) + ))) } } @@ -1934,7 +1957,7 @@ fn create_scheduler_decoder( is_structural, config.decoder_config.validate_on_decode, rx, - ); + )?; let scheduler_handle = tokio::task::spawn(async move { let mut decode_scheduler = match DecodeBatchScheduler::try_new( @@ -2097,7 +2120,7 @@ pub fn schedule_and_decode_blocking( config.decoder_config.validate_on_decode, is_structural, messages.into(), - ); + )?; Ok(decode_iterator) } @@ -2639,7 +2662,7 @@ pub async fn decode_batch( is_structural, should_validate, rx, - ); + )?; decode_stream.next().await.unwrap().task.await } diff --git a/rust/lance-encoding/src/encoder.rs b/rust/lance-encoding/src/encoder.rs index 7db365b5ecf..ee9fc36697d 100644 --- a/rust/lance-encoding/src/encoder.rs +++ b/rust/lance-encoding/src/encoder.rs @@ -30,6 +30,7 @@ use crate::compression_config::CompressionParams; use crate::decoder::PageEncoding; use crate::encodings::logical::blob::{BlobStructuralEncoder, BlobV2StructuralEncoder}; use crate::encodings::logical::list::ListStructuralEncoder; +use crate::encodings::logical::map::MapStructuralEncoder; use crate::encodings::logical::primitive::PrimitiveStructuralEncoder; use crate::encodings::logical::r#struct::StructStructuralEncoder; use crate::repdef::RepDefBuilder; @@ -436,6 +437,65 @@ impl StructuralEncodingStrategy { child_encoder, ))) } + DataType::Map(_, keys_sorted) => { + // TODO: We only support keys_sorted=false for now, + // because converting a rust arrow map field to the python arrow field will + // lose the keys_sorted property. + if keys_sorted { + return Err(Error::NotSupported { + source: format!("Map data type is not supported with keys_sorted=true now, current value is {}", keys_sorted).into(), + location: location!(), + }); + } + if self.version < LanceFileVersion::V2_2 { + return Err(Error::NotSupported { + source: format!( + "Map data type is only supported in Lance file format 2.2+, current version: {}", + self.version + ) + .into(), + location: location!(), + }); + } + let entries_child = field.children.first().ok_or_else(|| Error::Schema { + message: "Map should have an entries child".to_string(), + location: location!(), + })?; + let DataType::Struct(struct_fields) = entries_child.data_type() else { + return Err(Error::Schema { + message: "Map entries field must be a Struct".to_string(), + location: location!(), + }); + }; + if struct_fields.len() < 2 { + return Err(Error::Schema { + message: "Map entries struct must contain both key and value fields" + .to_string(), + location: location!(), + }); + } + let key_field = &struct_fields[0]; + if key_field.is_nullable() { + return Err(Error::Schema { + message: format!( + "Map key field '{}' must be non-nullable according to Arrow Map specification", + key_field.name() + ), + location: location!(), + }); + } + let child_encoder = self.do_create_field_encoder( + _encoding_strategy_root, + entries_child, + column_index, + options, + root_field_metadata, + )?; + Ok(Box::new(MapStructuralEncoder::new( + options.keep_original_array, + child_encoder, + ))) + } DataType::Struct(fields) => { if field.is_packed_struct() || fields.is_empty() { // Both packed structs and empty structs are encoded as primitive diff --git a/rust/lance-encoding/src/encodings/logical.rs b/rust/lance-encoding/src/encodings/logical.rs index e89ef14d956..4b1d186b79a 100644 --- a/rust/lance-encoding/src/encodings/logical.rs +++ b/rust/lance-encoding/src/encodings/logical.rs @@ -3,5 +3,6 @@ pub mod blob; pub mod list; +pub mod map; pub mod primitive; pub mod r#struct; diff --git a/rust/lance-encoding/src/encodings/logical/map.rs b/rust/lance-encoding/src/encodings/logical/map.rs new file mode 100644 index 00000000000..4205a01a892 --- /dev/null +++ b/rust/lance-encoding/src/encodings/logical/map.rs @@ -0,0 +1,672 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +use std::{ops::Range, sync::Arc}; + +use arrow_array::{Array, ArrayRef, MapArray}; +use arrow_schema::DataType; +use futures::future::BoxFuture; +use lance_arrow::deepcopy::deep_copy_nulls; +use lance_core::{Error, Result}; +use snafu::location; + +use crate::{ + decoder::{ + DecodedArray, FilterExpression, ScheduledScanLine, SchedulerContext, + StructuralDecodeArrayTask, StructuralFieldDecoder, StructuralFieldScheduler, + StructuralSchedulingJob, + }, + encoder::{EncodeTask, FieldEncoder, OutOfLineBuffers}, + repdef::RepDefBuilder, +}; + +/// A structural encoder for map fields +/// +/// Map in Arrow is represented as List> +/// The map's offsets are added to the rep/def builder +/// and the map's entries (struct array) are passed to the child encoder +pub struct MapStructuralEncoder { + keep_original_array: bool, + child: Box, +} + +impl MapStructuralEncoder { + pub fn new(keep_original_array: bool, child: Box) -> Self { + Self { + keep_original_array, + child, + } + } +} + +impl FieldEncoder for MapStructuralEncoder { + fn maybe_encode( + &mut self, + array: ArrayRef, + external_buffers: &mut OutOfLineBuffers, + mut repdef: RepDefBuilder, + row_number: u64, + num_rows: u64, + ) -> Result> { + let map_array = array + .as_any() + .downcast_ref::() + .expect("MapEncoder used for non-map data"); + + // Map internally has offsets and entries (struct array) + let entries = map_array.entries(); + let offsets = map_array.offsets(); + + // Add offsets to RepDefBuilder to handle nullability and list structure + if self.keep_original_array { + repdef.add_offsets(offsets.clone(), array.nulls().cloned()) + } else { + repdef.add_offsets(offsets.clone(), deep_copy_nulls(array.nulls())) + }; + + // Pass the entries (struct array) to the child encoder + // Convert to Arc + let entries_arc: ArrayRef = Arc::new(entries.clone()); + self.child + .maybe_encode(entries_arc, external_buffers, repdef, row_number, num_rows) + } + + fn flush(&mut self, external_buffers: &mut OutOfLineBuffers) -> Result> { + self.child.flush(external_buffers) + } + + fn num_columns(&self) -> u32 { + self.child.num_columns() + } + + fn finish( + &mut self, + external_buffers: &mut OutOfLineBuffers, + ) -> BoxFuture<'_, Result>> { + self.child.finish(external_buffers) + } +} + +#[derive(Debug)] +pub struct StructuralMapScheduler { + child: Box, +} + +impl StructuralMapScheduler { + pub fn new(child: Box) -> Self { + Self { child } + } +} + +impl StructuralFieldScheduler for StructuralMapScheduler { + fn schedule_ranges<'a>( + &'a self, + ranges: &[Range], + filter: &FilterExpression, + ) -> Result> { + let child = self.child.schedule_ranges(ranges, filter)?; + + Ok(Box::new(StructuralMapSchedulingJob::new(child))) + } + + fn initialize<'a>( + &'a mut self, + filter: &'a FilterExpression, + context: &'a SchedulerContext, + ) -> BoxFuture<'a, Result<()>> { + self.child.initialize(filter, context) + } +} + +/// Scheduling job for map data +/// +/// Scheduling is handled by the child encoder (struct) and nothing special +/// happens here, similar to list. +#[derive(Debug)] +struct StructuralMapSchedulingJob<'a> { + child: Box, +} + +impl<'a> StructuralMapSchedulingJob<'a> { + fn new(child: Box) -> Self { + Self { child } + } +} + +impl StructuralSchedulingJob for StructuralMapSchedulingJob<'_> { + fn schedule_next(&mut self, context: &mut SchedulerContext) -> Result> { + self.child.schedule_next(context) + } +} + +#[derive(Debug)] +pub struct StructuralMapDecoder { + child: Box, + data_type: DataType, +} + +impl StructuralMapDecoder { + pub fn new(child: Box, data_type: DataType) -> Self { + Self { child, data_type } + } +} + +impl StructuralFieldDecoder for StructuralMapDecoder { + fn accept_page(&mut self, child: crate::decoder::LoadedPageShard) -> Result<()> { + self.child.accept_page(child) + } + + fn drain(&mut self, num_rows: u64) -> Result> { + let child_task = self.child.drain(num_rows)?; + Ok(Box::new(StructuralMapDecodeTask::new( + child_task, + self.data_type.clone(), + ))) + } + + fn data_type(&self) -> &DataType { + &self.data_type + } +} + +#[derive(Debug)] +struct StructuralMapDecodeTask { + child_task: Box, + data_type: DataType, +} + +impl StructuralMapDecodeTask { + fn new(child_task: Box, data_type: DataType) -> Self { + Self { + child_task, + data_type, + } + } +} + +impl StructuralDecodeArrayTask for StructuralMapDecodeTask { + fn decode(self: Box) -> Result { + let DecodedArray { array, mut repdef } = self.child_task.decode()?; + + // Decode the offsets from RepDef + let (offsets, validity) = repdef.unravel_offsets::()?; + + // Extract the entries field and keys_sorted from the map data type + let (entries_field, keys_sorted) = match &self.data_type { + DataType::Map(field, keys_sorted) => { + if *keys_sorted { + return Err(Error::NotSupported { + source: "Map type decoder does not support keys_sorted=true now" + .to_string() + .into(), + location: location!(), + }); + } + (field.clone(), *keys_sorted) + } + _ => { + return Err(Error::Schema { + message: "Map decoder did not have a map field".to_string(), + location: location!(), + }); + } + }; + + // Convert the decoded array to StructArray + let entries = array + .as_any() + .downcast_ref::() + .ok_or_else(|| Error::Schema { + message: "Map entries should be a StructArray".to_string(), + location: location!(), + })? + .clone(); + + // Build the MapArray from offsets, entries, validity, and keys_sorted + let map_array = MapArray::new(entries_field, offsets, entries, validity, keys_sorted); + + Ok(DecodedArray { + array: Arc::new(map_array), + repdef, + }) + } +} + +#[cfg(test)] +mod tests { + use std::{collections::HashMap, sync::Arc}; + + use arrow_array::{ + builder::{Int32Builder, MapBuilder, StringBuilder}, + Array, Int32Array, MapArray, StringArray, StructArray, + }; + use arrow_buffer::{OffsetBuffer, ScalarBuffer}; + use arrow_schema::{DataType, Field, Fields}; + + use crate::encoder::{default_encoding_strategy, ColumnIndexSequence, EncodingOptions}; + use crate::{ + testing::{check_round_trip_encoding_of_data, TestCases}, + version::LanceFileVersion, + }; + use arrow_schema::Field as ArrowField; + use lance_core::datatypes::Field as LanceField; + + fn make_map_type(key_type: DataType, value_type: DataType) -> DataType { + // Note: Arrow MapBuilder uses "keys" and "values" as field names (plural) + let entries = Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("keys", key_type, false), + Field::new("values", value_type, true), + ])), + false, + ); + DataType::Map(Arc::new(entries), false) + } + + #[test_log::test(tokio::test)] + async fn test_simple_map() { + // Create a simple Map + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::new(); + let mut map_builder = MapBuilder::new(None, string_builder, int_builder); + + // Map 1: {"key1": 10, "key2": 20} + map_builder.keys().append_value("key1"); + map_builder.values().append_value(10); + map_builder.keys().append_value("key2"); + map_builder.values().append_value(20); + map_builder.append(true).unwrap(); + + // Map 2: {"key3": 30} + map_builder.keys().append_value("key3"); + map_builder.values().append_value(30); + map_builder.append(true).unwrap(); + + let map_array = map_builder.finish(); + + let test_cases = TestCases::default() + .with_range(0..2) + .with_min_file_version(LanceFileVersion::V2_2); + + check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new()) + .await; + } + + #[test_log::test(tokio::test)] + async fn test_empty_maps() { + // Test maps with empty entries + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::new(); + let mut map_builder = MapBuilder::new(None, string_builder, int_builder); + + // Map 1: {"a": 1} + map_builder.keys().append_value("a"); + map_builder.values().append_value(1); + map_builder.append(true).unwrap(); + + // Map 2: {} (empty) + map_builder.append(true).unwrap(); + + // Map 3: null + map_builder.append(false).unwrap(); + + // Map 4: {} (empty) + map_builder.append(true).unwrap(); + + let map_array = map_builder.finish(); + + let test_cases = TestCases::default() + .with_range(0..4) + .with_indices(vec![1]) + .with_indices(vec![2]) + .with_min_file_version(LanceFileVersion::V2_2); + + check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new()) + .await; + } + + #[test_log::test(tokio::test)] + async fn test_map_with_null_values() { + // Test Map with null values + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::new(); + let mut map_builder = MapBuilder::new(None, string_builder, int_builder); + + // Map 1: {"key1": 10, "key2": null} + map_builder.keys().append_value("key1"); + map_builder.values().append_value(10); + map_builder.keys().append_value("key2"); + map_builder.values().append_null(); + map_builder.append(true).unwrap(); + + // Map 2: {"key3": null} + map_builder.keys().append_value("key3"); + map_builder.values().append_null(); + map_builder.append(true).unwrap(); + + let map_array = map_builder.finish(); + + let test_cases = TestCases::default() + .with_range(0..2) + .with_indices(vec![0]) + .with_indices(vec![1]) + .with_min_file_version(LanceFileVersion::V2_2); + + check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new()) + .await; + } + + #[test_log::test(tokio::test)] + async fn test_map_in_struct() { + // Test Struct containing Map + // Struct> + + let string_key_builder = StringBuilder::new(); + let string_val_builder = StringBuilder::new(); + let mut map_builder = MapBuilder::new(None, string_key_builder, string_val_builder); + + // First struct: id=1, properties={"name": "Alice", "city": "NYC"} + map_builder.keys().append_value("name"); + map_builder.values().append_value("Alice"); + map_builder.keys().append_value("city"); + map_builder.values().append_value("NYC"); + map_builder.append(true).unwrap(); + + // Second struct: id=2, properties={"name": "Bob"} + map_builder.keys().append_value("name"); + map_builder.values().append_value("Bob"); + map_builder.append(true).unwrap(); + + // Third struct: id=3, properties=null + map_builder.append(false).unwrap(); + + let map_array = Arc::new(map_builder.finish()); + let id_array = Arc::new(Int32Array::from(vec![1, 2, 3])); + + let struct_array = StructArray::new( + Fields::from(vec![ + Field::new("id", DataType::Int32, false), + Field::new( + "properties", + make_map_type(DataType::Utf8, DataType::Utf8), + true, + ), + ]), + vec![id_array, map_array], + None, + ); + + let test_cases = TestCases::default() + .with_range(0..3) + .with_indices(vec![0, 2]) + .with_min_file_version(LanceFileVersion::V2_2); + + check_round_trip_encoding_of_data( + vec![Arc::new(struct_array)], + &test_cases, + HashMap::new(), + ) + .await; + } + + #[test_log::test(tokio::test)] + async fn test_list_of_maps() { + // Test List> + use arrow_array::builder::ListBuilder; + + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::new(); + let map_builder = MapBuilder::new(None, string_builder, int_builder); + let mut list_builder = ListBuilder::new(map_builder); + + // List 1: [{"a": 1}, {"b": 2}] + list_builder.values().keys().append_value("a"); + list_builder.values().values().append_value(1); + list_builder.values().append(true).unwrap(); + + list_builder.values().keys().append_value("b"); + list_builder.values().values().append_value(2); + list_builder.values().append(true).unwrap(); + + list_builder.append(true); + + // List 2: [{"c": 3}] + list_builder.values().keys().append_value("c"); + list_builder.values().values().append_value(3); + list_builder.values().append(true).unwrap(); + + list_builder.append(true); + + // List 3: [] (empty list) + list_builder.append(true); + + let list_array = list_builder.finish(); + + let test_cases = TestCases::default() + .with_range(0..3) + .with_indices(vec![0, 2]) + .with_min_file_version(LanceFileVersion::V2_2); + + check_round_trip_encoding_of_data(vec![Arc::new(list_array)], &test_cases, HashMap::new()) + .await; + } + + #[test_log::test(tokio::test)] + async fn test_nested_map() { + // Test Map> + // This is more complex as we need to build nested maps manually + + // Build inner maps first + let inner_string_builder = StringBuilder::new(); + let inner_int_builder = Int32Builder::new(); + let mut inner_map_builder1 = MapBuilder::new(None, inner_string_builder, inner_int_builder); + + // Inner map 1: {"x": 10} + inner_map_builder1.keys().append_value("x"); + inner_map_builder1.values().append_value(10); + inner_map_builder1.append(true).unwrap(); + + // Inner map 2: {"y": 20, "z": 30} + inner_map_builder1.keys().append_value("y"); + inner_map_builder1.values().append_value(20); + inner_map_builder1.keys().append_value("z"); + inner_map_builder1.values().append_value(30); + inner_map_builder1.append(true).unwrap(); + + let inner_maps = Arc::new(inner_map_builder1.finish()); + + // Build outer map keys + let outer_keys = Arc::new(StringArray::from(vec!["key1", "key2"])); + + // Build outer map structure + let entries_struct = StructArray::new( + Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new( + "value", + make_map_type(DataType::Utf8, DataType::Int32), + true, + ), + ]), + vec![outer_keys, inner_maps], + None, + ); + + let offsets = OffsetBuffer::new(ScalarBuffer::::from(vec![0, 2])); + let entries_field = Field::new("entries", entries_struct.data_type().clone(), false); + + let outer_map = MapArray::new( + Arc::new(entries_field), + offsets, + entries_struct, + None, + false, + ); + + let test_cases = TestCases::default() + .with_range(0..1) + .with_min_file_version(LanceFileVersion::V2_2); + + check_round_trip_encoding_of_data(vec![Arc::new(outer_map)], &test_cases, HashMap::new()) + .await; + } + + #[test_log::test(tokio::test)] + async fn test_map_different_key_types() { + // Test Map (integer keys) + let int_builder = Int32Builder::new(); + let string_builder = StringBuilder::new(); + let mut map_builder = MapBuilder::new(None, int_builder, string_builder); + + // Map 1: {1: "one", 2: "two"} + map_builder.keys().append_value(1); + map_builder.values().append_value("one"); + map_builder.keys().append_value(2); + map_builder.values().append_value("two"); + map_builder.append(true).unwrap(); + + // Map 2: {3: "three"} + map_builder.keys().append_value(3); + map_builder.values().append_value("three"); + map_builder.append(true).unwrap(); + + let map_array = map_builder.finish(); + + let test_cases = TestCases::default() + .with_range(0..2) + .with_indices(vec![0, 1]) + .with_min_file_version(LanceFileVersion::V2_2); + + check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new()) + .await; + } + + #[test_log::test(tokio::test)] + async fn test_map_with_extreme_sizes() { + // Test maps with large number of entries + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::new(); + let mut map_builder = MapBuilder::new(None, string_builder, int_builder); + + // Create a map with many entries + for i in 0..100 { + map_builder.keys().append_value(format!("key{}", i)); + map_builder.values().append_value(i); + } + map_builder.append(true).unwrap(); + + // Create a second map with no entries + map_builder.append(true).unwrap(); + + let map_array = map_builder.finish(); + + let test_cases = TestCases::default() + .with_range(0..2) + .with_min_file_version(LanceFileVersion::V2_2); + + check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new()) + .await; + } + + #[test_log::test(tokio::test)] + async fn test_map_all_null() { + // Test map where all entries are null + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::new(); + let mut map_builder = MapBuilder::new(None, string_builder, int_builder); + + // All null maps + map_builder.append(false).unwrap(); // null + map_builder.append(false).unwrap(); // null + + let map_array = map_builder.finish(); + + let test_cases = TestCases::default() + .with_range(0..2) + .with_min_file_version(LanceFileVersion::V2_2); + + check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new()) + .await; + } + + #[test_log::test(tokio::test)] + async fn test_map_encoder_keep_original_array_scenarios() { + // Test scenarios that highlight the difference between keep_original_array=true/false + // This test focuses on round-trip behavior which should be equivalent in both cases + let string_builder = StringBuilder::new(); + let int_builder = Int32Builder::new(); + let mut map_builder = MapBuilder::new(None, string_builder, int_builder); + + // Create a map with mixed null and non-null values to test both scenarios + // Map 1: {"key1": 10, "key2": null} + map_builder.keys().append_value("key1"); + map_builder.values().append_value(10); + map_builder.keys().append_value("key2"); + map_builder.values().append_null(); + map_builder.append(true).unwrap(); + + // Map 2: null + map_builder.append(false).unwrap(); + + // Map 3: {"key3": 30} + map_builder.keys().append_value("key3"); + map_builder.values().append_value(30); + map_builder.append(true).unwrap(); + + let map_array = map_builder.finish(); + + let test_cases = TestCases::default() + .with_range(0..3) + .with_indices(vec![0, 1, 2]) + .with_min_file_version(LanceFileVersion::V2_2); + + // This test ensures that regardless of the internal keep_original_array setting, + // the end-to-end behavior produces equivalent results + check_round_trip_encoding_of_data(vec![Arc::new(map_array)], &test_cases, HashMap::new()) + .await; + } + + #[test] + fn test_map_not_supported_write_in_v2_1() { + // Create a map field using Arrow Field first, then convert to Lance Field + let map_arrow_field = ArrowField::new( + "map_field", + make_map_type(DataType::Utf8, DataType::Int32), + true, + ); + let map_field = LanceField::try_from(&map_arrow_field).unwrap(); + + // Test encoder: Try to create encoder with V2_1 version - should fail + let encoder_strategy = default_encoding_strategy(LanceFileVersion::V2_1); + let mut column_index = ColumnIndexSequence::default(); + let options = EncodingOptions::default(); + + let encoder_result = encoder_strategy.create_field_encoder( + encoder_strategy.as_ref(), + &map_field, + &mut column_index, + &options, + ); + + assert!( + encoder_result.is_err(), + "Map type should not be supported in V2_1 for encoder" + ); + let Err(encoder_err) = encoder_result else { + panic!("Expected error but got Ok") + }; + + let encoder_err_msg = format!("{}", encoder_err); + assert!( + encoder_err_msg.contains("2.2"), + "Encoder error message should mention version 2.2, got: {}", + encoder_err_msg + ); + assert!( + encoder_err_msg.contains("Map data type"), + "Encoder error message should mention Map data type, got: {}", + encoder_err_msg + ); + } +} diff --git a/rust/lance-encoding/src/encodings/logical/struct.rs b/rust/lance-encoding/src/encodings/logical/struct.rs index 0da9ec38d2d..0a53ec9a21c 100644 --- a/rust/lance-encoding/src/encodings/logical/struct.rs +++ b/rust/lance-encoding/src/encodings/logical/struct.rs @@ -7,6 +7,10 @@ use std::{ sync::Arc, }; +use super::{ + list::StructuralListDecoder, map::StructuralMapDecoder, + primitive::StructuralPrimitiveFieldDecoder, +}; use crate::{ decoder::{ DecodedArray, FilterExpression, LoadedPageShard, NextDecodeTask, PageEncoding, @@ -27,10 +31,9 @@ use futures::{ use itertools::Itertools; use lance_arrow::FieldExt; use lance_arrow::{deepcopy::deep_copy_nulls, r#struct::StructArrayExt}; -use lance_core::Result; +use lance_core::{Error, Result}; use log::trace; - -use super::{list::StructuralListDecoder, primitive::StructuralPrimitiveFieldDecoder}; +use snafu::location; #[derive(Debug)] struct StructuralSchedulingJobWithStatus<'a> { @@ -237,46 +240,63 @@ pub struct StructuralStructDecoder { } impl StructuralStructDecoder { - pub fn new(fields: Fields, should_validate: bool, is_root: bool) -> Self { + pub fn new(fields: Fields, should_validate: bool, is_root: bool) -> Result { let children = fields .iter() .map(|field| Self::field_to_decoder(field, should_validate)) - .collect(); + .collect::>>()?; let data_type = DataType::Struct(fields.clone()); - Self { + Ok(Self { data_type, children, child_fields: fields, is_root, - } + }) } fn field_to_decoder( field: &Arc, should_validate: bool, - ) -> Box { + ) -> Result> { match field.data_type() { DataType::Struct(fields) => { if field.is_packed_struct() || field.is_blob() { let decoder = StructuralPrimitiveFieldDecoder::new(&field.clone(), should_validate); - Box::new(decoder) + Ok(Box::new(decoder)) } else { - Box::new(Self::new(fields.clone(), should_validate, false)) + Ok(Box::new(Self::new(fields.clone(), should_validate, false)?)) } } DataType::List(child_field) | DataType::LargeList(child_field) => { - let child_decoder = Self::field_to_decoder(child_field, should_validate); - Box::new(StructuralListDecoder::new( + let child_decoder = Self::field_to_decoder(child_field, should_validate)?; + Ok(Box::new(StructuralListDecoder::new( + child_decoder, + field.data_type().clone(), + ))) + } + DataType::Map(entries_field, keys_sorted) => { + if *keys_sorted { + return Err(Error::NotSupported { + source: "Map data type with keys_sorted=true is not supported yet" + .to_string() + .into(), + location: location!(), + }); + } + let child_decoder = Self::field_to_decoder(entries_field, should_validate)?; + Ok(Box::new(StructuralMapDecoder::new( child_decoder, field.data_type().clone(), - )) + ))) } DataType::RunEndEncoded(_, _) => todo!(), DataType::ListView(_) | DataType::LargeListView(_) => todo!(), - DataType::Map(_, _) => todo!(), DataType::Union(_, _) => todo!(), - _ => Box::new(StructuralPrimitiveFieldDecoder::new(field, should_validate)), + _ => Ok(Box::new(StructuralPrimitiveFieldDecoder::new( + field, + should_validate, + ))), } } diff --git a/rust/lance-encoding/src/testing.rs b/rust/lance-encoding/src/testing.rs index 3b6c43a3d2a..f6bc8cda268 100644 --- a/rust/lance-encoding/src/testing.rs +++ b/rust/lance-encoding/src/testing.rs @@ -126,6 +126,14 @@ fn column_indices_from_schema_helper( is_structural_encoding, ); } + DataType::Map(entries, _) => { + column_indices_from_schema_helper( + std::slice::from_ref(entries), + column_indices, + column_counter, + is_structural_encoding, + ); + } DataType::FixedSizeList(inner, _) => { // FSL(primitive) does not get its own column in either approach column_indices_from_schema_helper( @@ -209,7 +217,8 @@ async fn test_decode( is_structural_encoding, /*should_validate=*/ true, rx, - ); + ) + .unwrap(); let mut offset = 0; while let Some(batch) = decode_stream.next().await { diff --git a/rust/lance-namespace/src/schema.rs b/rust/lance-namespace/src/schema.rs index b51e0cc2bd5..e0b66f52a07 100644 --- a/rust/lance-namespace/src/schema.rs +++ b/rust/lance-namespace/src/schema.rs @@ -181,11 +181,36 @@ fn arrow_type_to_json(data_type: &DataType) -> Result { arrow_type_to_json(value_type) } + DataType::Map(entries_field, keys_sorted) => { + if *keys_sorted { + return Err(Error::Namespace { + source: format!( + "Map types with keys_sorted=true are not yet supported for JSON conversion: {:?}", + data_type + ) + .into(), + location: Location::new(file!(), line!(), column!()), + }); + } + let inner_type = arrow_type_to_json(entries_field.data_type())?; + let inner_field = JsonArrowField { + name: entries_field.name().clone(), + nullable: entries_field.is_nullable(), + r#type: Box::new(inner_type), + metadata: if entries_field.metadata().is_empty() { + None + } else { + Some(entries_field.metadata().clone()) + }, + }; + Ok(JsonArrowDataType { + r#type: "map".to_string(), + fields: Some(vec![inner_field]), + length: None, + }) + } + // Unsupported types - DataType::Map(_, _) => Err(Error::Namespace { - source: "Map type is not supported by Lance".into(), - location: Location::new(file!(), line!(), column!()), - }), DataType::RunEndEncoded(_, _) => Err(Error::Namespace { source: format!( "RunEndEncoded type is not yet supported for JSON conversion: {:?}", @@ -431,7 +456,7 @@ mod tests { } #[test] - fn test_map_type_unsupported() { + fn test_map_type_supported() { use arrow::datatypes::Field; let key_field = Field::new("keys", DataType::Utf8, false); @@ -446,11 +471,15 @@ mod tests { ); let result = arrow_type_to_json(&map_type); - assert!(result.is_err()); - assert!(result - .unwrap_err() - .to_string() - .contains("Map type is not supported")); + assert!(result.is_ok()); + let json_type = result.unwrap(); + assert_eq!(json_type.r#type, "map"); + assert!(json_type.fields.is_some()); + + let fields = json_type.fields.unwrap(); + assert_eq!(fields.len(), 1); + assert_eq!(fields[0].name, "entries"); + assert_eq!(fields[0].r#type.r#type, "struct"); } #[test]