diff --git a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java index 15d670d45fd6..2631983d42a5 100644 --- a/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java +++ b/core/src/main/java/org/apache/iceberg/avro/AvroSchemaVisitor.java @@ -58,7 +58,7 @@ public static T visit(Schema schema, AvroSchemaVisitor visitor) { return visitor.union(schema, options); case ARRAY: - if (schema.getLogicalType() instanceof LogicalMap || AvroSchemaUtil.isKeyValueSchema(schema.getElementType())) { + if (schema.getLogicalType() instanceof LogicalMap) { return visitor.array(schema, visit(schema.getElementType(), visitor)); } else { return visitor.array(schema, visitWithName("element", schema.getElementType(), visitor)); diff --git a/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java b/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java index 0fd52124d7e3..eb13b4bd65a9 100644 --- a/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java +++ b/core/src/main/java/org/apache/iceberg/avro/PruneColumns.java @@ -121,7 +121,7 @@ public Schema union(Schema union, List options) { @Override @SuppressWarnings("checkstyle:CyclomaticComplexity") public Schema array(Schema array, Schema element) { - if (array.getLogicalType() instanceof LogicalMap || AvroSchemaUtil.isKeyValueSchema(array.getElementType())) { + if (array.getLogicalType() instanceof LogicalMap) { Schema keyValue = array.getElementType(); Integer keyId = AvroSchemaUtil.getFieldId(keyValue.getField("key"), nameMapping, fieldNames()); Integer valueId = AvroSchemaUtil.getFieldId(keyValue.getField("value"), nameMapping, fieldNames()); diff --git a/core/src/test/java/org/apache/iceberg/avro/RemoveIds.java b/core/src/test/java/org/apache/iceberg/avro/RemoveIds.java index da13e37c162e..b26b1c2e8d91 100644 --- a/core/src/test/java/org/apache/iceberg/avro/RemoveIds.java +++ b/core/src/test/java/org/apache/iceberg/avro/RemoveIds.java @@ -38,12 +38,26 @@ public Schema record(Schema record, List names, List types) { @Override public Schema map(Schema map, Schema valueType) { - return Schema.createMap(valueType); + Schema result = Schema.createMap(valueType); + for (Map.Entry prop : map.getObjectProps().entrySet()) { + String key = prop.getKey(); + if (!key.equals(AvroSchemaUtil.KEY_ID_PROP) && !key.equals(AvroSchemaUtil.VALUE_ID_PROP)) { + result.addProp(key, prop.getValue()); + } + } + return result; } @Override public Schema array(Schema array, Schema element) { - return Schema.createArray(element); + Schema result = Schema.createArray(element); + for (Map.Entry prop : array.getObjectProps().entrySet()) { + String key = prop.getKey(); + if (!key.equals(AvroSchemaUtil.ELEMENT_ID_PROP)) { + result.addProp(key, prop.getValue()); + } + } + return result; } @Override diff --git a/core/src/test/java/org/apache/iceberg/avro/TestAvroNameMapping.java b/core/src/test/java/org/apache/iceberg/avro/TestAvroNameMapping.java index 4682a1f048c6..b8e1d419e423 100644 --- a/core/src/test/java/org/apache/iceberg/avro/TestAvroNameMapping.java +++ b/core/src/test/java/org/apache/iceberg/avro/TestAvroNameMapping.java @@ -38,6 +38,7 @@ import org.apache.iceberg.mapping.MappedFields; import org.apache.iceberg.mapping.MappingUtil; import org.apache.iceberg.mapping.NameMapping; +import org.apache.iceberg.types.Comparators; import org.apache.iceberg.types.Types; import org.junit.Assert; import org.junit.Test; @@ -92,6 +93,70 @@ public void testMapProjections() throws IOException { Assert.assertNull("location.value.long, should not be read", projectedL1.get("long_r2")); } + @Test + public void testComplexMapKeys() throws IOException { + Schema writeSchema = new Schema( + Types.NestedField.required(5, "location", Types.MapType.ofRequired(6, 7, + Types.StructType.of( + Types.NestedField.required(3, "k1", Types.StringType.get()), + Types.NestedField.required(4, "k2", Types.StringType.get()) + ), + Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.optional(2, "long", Types.FloatType.get()) + ) + ))); + + Record record = new Record(AvroSchemaUtil.convert(writeSchema, "table")); + org.apache.avro.Schema locationSchema = record.getSchema().getField("location").schema(); + Record locationElement = new Record(locationSchema.getElementType()); + Record locationKey = new Record(locationElement.getSchema().getField("key").schema()); + Record locationValue = new Record(locationElement.getSchema().getField("value").schema()); + + locationKey.put("k1", "k1"); + locationKey.put("k2", "k2"); + locationValue.put("lat", 52.995143f); + locationValue.put("long", -1.539054f); + locationElement.put("key", locationKey); + locationElement.put("value", locationValue); + record.put("location", ImmutableList.of(locationElement)); + + // project a subset of the map's value columns in NameMapping + NameMapping nameMapping = MappingUtil.create(new Schema( + Types.NestedField.required(5, "location", Types.MapType.ofOptional(6, 7, + Types.StructType.of( + Types.NestedField.required(3, "k1", Types.StringType.get()), + Types.NestedField.optional(4, "k2", Types.StringType.get()) + ), + Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()) + ) + )))); + + Schema readSchema = new Schema( + Types.NestedField.required(5, "location", Types.MapType.ofOptional(6, 7, + Types.StructType.of( + Types.NestedField.required(3, "k1", Types.StringType.get()), + Types.NestedField.optional(4, "k2", Types.StringType.get()) + ), + Types.StructType.of( + Types.NestedField.required(1, "lat", Types.FloatType.get()), + Types.NestedField.optional(2, "long", Types.FloatType.get()) + ) + ))); + + Record projected = writeAndRead(writeSchema, readSchema, record, nameMapping); + // The data is read back as a map + Map projectedLocation = (Map) projected.get("location"); + Record projectedKey = projectedLocation.keySet().iterator().next(); + Record projectedValue = projectedLocation.values().iterator().next(); + Assert.assertEquals(0, Comparators.charSequences().compare("k1", (CharSequence) projectedKey.get("k1"))); + Assert.assertEquals(0, Comparators.charSequences().compare("k2", (CharSequence) projectedKey.get("k2"))); + Assert.assertEquals(52.995143f, projectedValue.get("lat")); + Assert.assertNotNull(projectedValue.getSchema().getField("long_r2")); + Assert.assertNull(projectedValue.get("long_r2")); + } + @Test public void testMissingRequiredFields() { Schema writeSchema = new Schema(