diff --git a/.palantir/revapi.yml b/.palantir/revapi.yml index 21abbeb19ab7..0df4fc1a2747 100644 --- a/.palantir/revapi.yml +++ b/.palantir/revapi.yml @@ -745,6 +745,14 @@ acceptedBreaks: \ of type erasure and the original type is always returned" "1.3.0": org.apache.iceberg:iceberg-api: + - code: "java.class.defaultSerializationChanged" + old: "class org.apache.iceberg.PartitionField" + new: "class org.apache.iceberg.PartitionField" + justification: "Added a new field" + - code: "java.class.defaultSerializationChanged" + old: "class org.apache.iceberg.SortField" + new: "class org.apache.iceberg.SortField" + justification: "Added a new field" - code: "java.class.removed" old: "class org.apache.iceberg.actions.ImmutableDeleteOrphanFiles" justification: "Moving from iceberg-api to iceberg-core" diff --git a/api/src/main/java/org/apache/iceberg/Accessors.java b/api/src/main/java/org/apache/iceberg/Accessors.java index 08233624f244..f3e738b433d0 100644 --- a/api/src/main/java/org/apache/iceberg/Accessors.java +++ b/api/src/main/java/org/apache/iceberg/Accessors.java @@ -18,12 +18,16 @@ */ package org.apache.iceberg; +import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.types.Type; import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.types.Types; +import org.apache.iceberg.util.StructProjection; /** * Position2Accessor and Position3Accessor here is an optimization. 
For a nested schema like: @@ -55,6 +59,49 @@ static Map> forSchema(Schema schema) { return TypeUtil.visit(schema, new BuildPositionAccessors()); } + static Accessor projectFields(Schema schema, int[] fieldIds) { + Preconditions.checkArgument( + fieldIds != null && fieldIds.length > 1, "at least two fields must be selected"); + List fields = + Arrays.stream(fieldIds).mapToObj(schema::findField).collect(Collectors.toList()); + Types.StructType projected = Types.StructType.of(fields); + // TODO: handle the case where all the projected fields are deleted, which should always + // produce a null struct. + StructProjection projection = StructProjection.createAllowMissing(schema.asStruct(), projected); + return new StructProjectionAccessor(projection, projected); + } + + private static class StructProjectionAccessor implements Accessor { + private final StructProjection projection; + private final Types.StructType type; + + StructProjectionAccessor(StructProjection projection, Types.StructType type) { + this.projection = projection; + this.type = type; + } + + @Override + public Object get(StructLike row) { + return projection.wrap(row); + } + + @Override + public Type type() { + return type; + } + + public Class javaClass() { + return type.typeId().javaClass(); + } + + @Override + public String toString() { + String[] fieldNames = + type.fields().stream().map(Types.NestedField::name).toArray(String[]::new); + return "Accessor(fieldNames=" + Arrays.toString(fieldNames) + ", type=" + type + ")"; + } + } + private static class PositionAccessor implements Accessor { private final int position; private final Type type; diff --git a/api/src/main/java/org/apache/iceberg/PartitionField.java b/api/src/main/java/org/apache/iceberg/PartitionField.java index 3ed765a89834..6ecf55d172c2 100644 --- a/api/src/main/java/org/apache/iceberg/PartitionField.java +++ b/api/src/main/java/org/apache/iceberg/PartitionField.java @@ -19,18 +19,32 @@ package org.apache.iceberg; import java.io.Serializable; +import 
java.util.Arrays; import org.apache.iceberg.relocated.com.google.common.base.Objects; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.transforms.Transform; /** Represents a single field in a {@link PartitionSpec}. */ public class PartitionField implements Serializable { private final int sourceId; + private final int[] sourceIds; private final int fieldId; private final String name; private final Transform transform; PartitionField(int sourceId, int fieldId, String name, Transform transform) { this.sourceId = sourceId; + this.sourceIds = new int[] {sourceId}; + this.fieldId = fieldId; + this.name = name; + this.transform = transform; + } + + PartitionField(int[] sourceIds, int fieldId, String name, Transform transform) { + Preconditions.checkArgument( + sourceIds != null && sourceIds.length >= 1, "At least one source id should be provided"); + this.sourceId = sourceIds.length > 1 ? -1 : sourceIds[0]; + this.sourceIds = sourceIds; this.fieldId = fieldId; this.name = name; this.transform = transform; @@ -41,6 +55,10 @@ public int sourceId() { return sourceId; } + public int[] sourceIds() { + return sourceIds; + } + /** Returns the partition field id across all the table metadata's partition specs. 
*/ public int fieldId() { return fieldId; } @@ -58,7 +76,11 @@ public String name() { @Override public String toString() { - return fieldId + ": " + name + ": " + transform + "(" + sourceId + ")"; + if (sourceIds.length == 1) { + return fieldId + ": " + name + ": " + transform + "(" + sourceId + ")"; + } else { + return fieldId + ": " + name + ": " + transform + "(" + Arrays.toString(sourceIds) + ")"; + } } @Override @@ -71,6 +93,7 @@ public boolean equals(Object other) { PartitionField that = (PartitionField) other; return sourceId == that.sourceId + && Arrays.equals(sourceIds, that.sourceIds) && fieldId == that.fieldId && name.equals(that.name) && transform.toString().equals(that.transform.toString()); @@ -78,6 +101,6 @@ public boolean equals(Object other) { @Override public int hashCode() { - return Objects.hashCode(sourceId, fieldId, name, transform); + return Objects.hashCode(sourceId, Arrays.hashCode(sourceIds), fieldId, name, transform); } } diff --git a/api/src/main/java/org/apache/iceberg/PartitionKey.java b/api/src/main/java/org/apache/iceberg/PartitionKey.java index fc56d1a45347..b626ee96bee7 100644 --- a/api/src/main/java/org/apache/iceberg/PartitionKey.java +++ b/api/src/main/java/org/apache/iceberg/PartitionKey.java @@ -53,7 +53,12 @@ public PartitionKey(PartitionSpec spec, Schema inputSchema) { Schema schema = spec.schema(); for (int i = 0; i < size; i += 1) { PartitionField field = fields.get(i); - Accessor accessor = inputSchema.accessorForField(field.sourceId()); + Accessor accessor; + if (field.sourceIds().length == 1) { + accessor = inputSchema.accessorForField(field.sourceId()); + } else { + accessor = inputSchema.accessorForFields(field.sourceIds()); + } Preconditions.checkArgument( accessor != null, "Cannot build accessor for field: " + schema.findField(field.sourceId())); diff --git a/api/src/main/java/org/apache/iceberg/PartitionSpec.java b/api/src/main/java/org/apache/iceberg/PartitionSpec.java index a31cfd76583b..ea2645d33469 100--- 
a/api/src/main/java/org/apache/iceberg/PartitionSpec.java +++ b/api/src/main/java/org/apache/iceberg/PartitionSpec.java @@ -28,6 +28,7 @@ import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.base.Joiner; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; import org.apache.iceberg.relocated.com.google.common.collect.ListMultimap; @@ -103,7 +104,11 @@ public UnboundPartitionSpec toUnbound() { for (PartitionField field : fields) { builder.addField( - field.transform().toString(), field.sourceId(), field.fieldId(), field.name()); + field.transform().toString(), + field.sourceId(), + field.sourceIds(), + field.fieldId(), + field.name()); } return builder.build(); @@ -354,24 +359,20 @@ private int nextFieldId() { return lastAssignedFieldId.incrementAndGet(); } - private void checkAndAddPartitionName(String name) { - checkAndAddPartitionName(name, null); - } - Builder checkConflicts(boolean check) { checkConflicts = check; return this; } - private void checkAndAddPartitionName(String name, Integer sourceColumnId) { + private void checkAndAddPartitionName(String name, int... 
sourceIds) { Types.NestedField schemaField = schema.findField(name); if (checkConflicts) { - if (sourceColumnId != null) { + if (sourceIds.length == 1) { // for identity transform case we allow conflicts between partition and schema field name // as // long as they are sourced from the same schema field Preconditions.checkArgument( - schemaField == null || schemaField.fieldId() == sourceColumnId, + schemaField == null || schemaField.fieldId() == sourceIds[0], "Cannot create identity partition sourced from different field in schema: %s", name); } else { @@ -504,6 +505,20 @@ public Builder hour(String sourceName) { return hour(sourceName, sourceName + "_hour"); } + public Builder bucket(String[] sourceNames, int numBuckets, String targetName) { + Types.NestedField[] sourceColumns = new Types.NestedField[sourceNames.length]; + int[] sourceColumnIds = new int[sourceNames.length]; + for (int i = 0; i < sourceNames.length; i++) { + sourceColumns[i] = findSourceColumn(sourceNames[i]); + sourceColumnIds[i] = sourceColumns[i].fieldId(); + } + Types.StructType type = Types.StructType.of(sourceColumns); + fields.add( + new PartitionField( + sourceColumnIds, nextFieldId(), targetName, Transforms.bucket(type, numBuckets))); + return this; + } + public Builder bucket(String sourceName, int numBuckets, String targetName) { checkAndAddPartitionName(targetName); Types.NestedField sourceColumn = findSourceColumn(sourceName); @@ -520,6 +535,16 @@ public Builder bucket(String sourceName, int numBuckets) { return bucket(sourceName, numBuckets, sourceName + "_bucket"); } + public Builder bucket(String[] sourceNames, int numBuckets) { + Preconditions.checkArgument(sourceNames != null && sourceNames.length >= 1); + if (sourceNames.length == 1) { + return bucket(sourceNames[0], numBuckets); + } else { + String targetName = Joiner.on("_").join(sourceNames); + return bucket(sourceNames, numBuckets, targetName + "_bucket"); + } + } + public Builder truncate(String sourceName, int width, 
String targetName) { checkAndAddPartitionName(targetName); Types.NestedField sourceColumn = findSourceColumn(sourceName); @@ -556,6 +581,10 @@ Builder add(int sourceId, String name, Transform transform) { return add(sourceId, nextFieldId(), name, transform); } + Builder add(int[] sourceIds, String name, Transform transform) { + return add(sourceIds, nextFieldId(), name, transform); + } + Builder add(int sourceId, int fieldId, String name, Transform transform) { checkAndAddPartitionName(name, sourceId); fields.add(new PartitionField(sourceId, fieldId, name, transform)); @@ -563,6 +592,13 @@ Builder add(int sourceId, int fieldId, String name, Transform transform) { return this; } + Builder add(int[] sourceIds, int fieldId, String name, Transform transform) { + checkAndAddPartitionName(name, sourceIds); + fields.add(new PartitionField(sourceIds, fieldId, name, transform)); + lastAssignedFieldId.getAndAccumulate(fieldId, Math::max); + return this; + } + public PartitionSpec build() { PartitionSpec spec = buildUnchecked(); checkCompatibility(spec, schema); @@ -576,25 +612,28 @@ PartitionSpec buildUnchecked() { static void checkCompatibility(PartitionSpec spec, Schema schema) { for (PartitionField field : spec.fields) { - Type sourceType = schema.findType(field.sourceId()); Transform transform = field.transform(); - // In the case of a Version 1 partition-spec field gets deleted, - // it is replaced with a void transform, see: - // https://iceberg.apache.org/spec/#partition-transforms - // We don't care about the source type since a VoidTransform is always compatible and skip the - // checks - if (!transform.equals(Transforms.alwaysNull())) { - ValidationException.check( - sourceType != null, "Cannot find source column for partition field: %s", field); - ValidationException.check( - sourceType.isPrimitiveType(), - "Cannot partition by non-primitive source field: %s", - sourceType); - ValidationException.check( - transform.canTransform(sourceType), - "Invalid source type 
%s for transform: %s", - sourceType, - transform); + for (int id : field.sourceIds()) { + Type sourceType = schema.findType(id); + // In the case of a Version 1 partition-spec field gets deleted, + // it is replaced with a void transform, see: + // https://iceberg.apache.org/spec/#partition-transforms + // We don't care about the source type since a VoidTransform is always compatible and skip + // the + // checks + if (!transform.equals(Transforms.alwaysNull())) { + ValidationException.check( + sourceType != null, "Cannot find source column for partition field: %s", field); + ValidationException.check( + sourceType.isPrimitiveType(), + "Cannot partition by non-primitive source field: %s", + sourceType); + ValidationException.check( + transform.canTransform(sourceType), + "Invalid source type %s for transform: %s", + sourceType, + transform); + } } } } diff --git a/api/src/main/java/org/apache/iceberg/Schema.java b/api/src/main/java/org/apache/iceberg/Schema.java index 5e024b7c1c29..52ae17715927 100644 --- a/api/src/main/java/org/apache/iceberg/Schema.java +++ b/api/src/main/java/org/apache/iceberg/Schema.java @@ -413,6 +413,13 @@ public Accessor accessorForField(int id) { return lazyIdToAccessor().get(id); } + public Accessor accessorForFields(int[] ids) { + if (ids.length == 1) { + return accessorForField(ids[0]); + } + return Accessors.projectFields(this, ids); + } + /** * Creates a projection schema for a subset of columns, selected by name. 
* diff --git a/api/src/main/java/org/apache/iceberg/SortField.java b/api/src/main/java/org/apache/iceberg/SortField.java index d7f110a26e3f..d9ed1b70c925 100644 --- a/api/src/main/java/org/apache/iceberg/SortField.java +++ b/api/src/main/java/org/apache/iceberg/SortField.java @@ -19,7 +19,9 @@ package org.apache.iceberg; import java.io.Serializable; +import java.util.Arrays; import java.util.Objects; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.transforms.Transform; /** A field in a {@link SortOrder}. */ @@ -27,12 +29,25 @@ public class SortField implements Serializable { private final Transform transform; private final int sourceId; + private final int[] sourceIds; private final SortDirection direction; private final NullOrder nullOrder; SortField(Transform transform, int sourceId, SortDirection direction, NullOrder nullOrder) { this.transform = transform; this.sourceId = sourceId; + this.sourceIds = new int[] {sourceId}; + this.direction = direction; + this.nullOrder = nullOrder; + } + + SortField( + Transform transform, int[] sourceIds, SortDirection direction, NullOrder nullOrder) { + Preconditions.checkArgument( + sourceIds != null && sourceIds.length >= 1, "at least one source id should be provided"); + this.transform = transform; + this.sourceId = sourceIds.length > 1 ? 
-1 : sourceIds[0]; + this.sourceIds = sourceIds; this.direction = direction; this.nullOrder = nullOrder; } @@ -54,6 +69,10 @@ public int sourceId() { return sourceId; } + public int[] sourceIds() { + return sourceIds; + } + /** Returns the sort direction */ public SortDirection direction() { return direction; @@ -74,6 +93,7 @@ public boolean satisfies(SortField other) { if (Objects.equals(this, other)) { return true; } else if (sourceId != other.sourceId + || !Arrays.equals(sourceIds, other.sourceIds) || direction != other.direction || nullOrder != other.nullOrder) { return false; @@ -84,7 +104,11 @@ public boolean satisfies(SortField other) { @Override public String toString() { - return transform + "(" + sourceId + ") " + direction + " " + nullOrder; + if (sourceIds.length == 1) { + return transform + "(" + sourceId + ") " + direction + " " + nullOrder; + } else { + return transform + "(" + Arrays.toString(sourceIds) + ")" + direction + " " + nullOrder; + } } @Override @@ -98,6 +122,7 @@ public boolean equals(Object other) { SortField that = (SortField) other; return transform.toString().equals(that.transform.toString()) && sourceId == that.sourceId + && Arrays.equals(sourceIds, that.sourceIds) && direction == that.direction && nullOrder == that.nullOrder; } diff --git a/api/src/main/java/org/apache/iceberg/SortOrder.java b/api/src/main/java/org/apache/iceberg/SortOrder.java index d0041cefc1c4..4e09589b3370 100644 --- a/api/src/main/java/org/apache/iceberg/SortOrder.java +++ b/api/src/main/java/org/apache/iceberg/SortOrder.java @@ -22,6 +22,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.stream.Collectors; import java.util.stream.IntStream; import org.apache.iceberg.exceptions.ValidationException; import org.apache.iceberg.expressions.BoundReference; @@ -36,6 +37,7 @@ import org.apache.iceberg.transforms.Transform; import org.apache.iceberg.transforms.Transforms; import org.apache.iceberg.types.Type; +import 
org.apache.iceberg.types.Types; /** A sort order that defines how data and delete files should be ordered in a table. */ public class SortOrder implements Serializable { @@ -127,7 +129,11 @@ public UnboundSortOrder toUnbound() { for (SortField field : fields) { builder.addSortField( - field.transform().toString(), field.sourceId(), field.direction(), field.nullOrder()); + field.transform().toString(), + field.sourceId(), + field.sourceIds(), + field.direction(), + field.nullOrder()); } return builder.build(); @@ -247,8 +253,13 @@ private Builder addSortField(Term term, SortDirection direction, NullOrder nullO // ValidationException is thrown by bind if binding fails so we assume that boundTerm is // correct BoundTerm boundTerm = ((UnboundTerm) term).bind(schema.asStruct(), caseSensitive); - int sourceId = boundTerm.ref().fieldId(); - SortField sortField = new SortField(toTransform(boundTerm), sourceId, direction, nullOrder); + int[] sourceIds = + boundTerm.refs().stream() + .map(BoundReference::field) + .map(Types.NestedField::fieldId) + .mapToInt(x -> x) + .toArray(); + SortField sortField = new SortField(toTransform(boundTerm), sourceIds, direction, nullOrder); fields.add(sortField); return this; } @@ -297,13 +308,20 @@ SortOrder buildUnchecked() { public static void checkCompatibility(SortOrder sortOrder, Schema schema) { for (SortField field : sortOrder.fields) { - Type sourceType = schema.findType(field.sourceId()); - ValidationException.check( - sourceType != null, "Cannot find source column for sort field: %s", field); - ValidationException.check( - sourceType.isPrimitiveType(), - "Cannot sort by non-primitive source field: %s", - sourceType); + int[] sourceIds = field.sourceIds(); + for (int sourceId : sourceIds) { + Type sourceType = schema.findType(sourceId); + ValidationException.check( + sourceType != null, "Cannot find source column for sort field: %s", field); + ValidationException.check( + sourceType.isPrimitiveType(), + "Cannot sort by non-primitive 
source field: %s", + sourceType); + } + List sourceFields = + Arrays.stream(sourceIds).mapToObj(schema::findField).collect(Collectors.toList()); + Type sourceType = + sourceIds.length == 1 ? sourceFields.get(0).type() : Types.StructType.of(sourceFields); ValidationException.check( field.transform().canTransform(sourceType), "Invalid source type %s for transform: %s", diff --git a/api/src/main/java/org/apache/iceberg/UnboundPartitionSpec.java b/api/src/main/java/org/apache/iceberg/UnboundPartitionSpec.java index cc8526f9072c..359ac3ae24cb 100644 --- a/api/src/main/java/org/apache/iceberg/UnboundPartitionSpec.java +++ b/api/src/main/java/org/apache/iceberg/UnboundPartitionSpec.java @@ -19,6 +19,7 @@ package org.apache.iceberg; import java.util.List; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.transforms.Transform; import org.apache.iceberg.transforms.Transforms; @@ -54,7 +55,7 @@ private PartitionSpec.Builder copyToBuilder(Schema schema) { PartitionSpec.Builder builder = PartitionSpec.builderFor(schema).withSpecId(specId); for (UnboundPartitionField field : fields) { - Type fieldType = schema.findType(field.sourceId); + Type fieldType = schema.findType(field.sourceId()); Transform transform; if (fieldType != null) { transform = Transforms.fromString(fieldType, field.transform.toString()); @@ -62,9 +63,9 @@ private PartitionSpec.Builder copyToBuilder(Schema schema) { transform = Transforms.fromString(field.transform.toString()); } if (field.partitionId != null) { - builder.add(field.sourceId, field.partitionId, field.name, transform); + builder.add(field.sourceIds, field.partitionId, field.name, transform); } else { - builder.add(field.sourceId, field.name, transform); + builder.add(field.sourceIds, field.name, transform); } } @@ -88,13 +89,22 @@ Builder withSpecId(int newSpecId) { return this; } - Builder addField(String transformAsString, int 
sourceId, int partitionId, String name) { - fields.add(new UnboundPartitionField(transformAsString, sourceId, partitionId, name)); + Builder addField( + String transformAsString, int sourceId, int[] sourceIds, int partitionId, String name) { + if (sourceIds.length == 1) { + fields.add(new UnboundPartitionField(transformAsString, sourceId, partitionId, name)); + } else { + fields.add(new UnboundPartitionField(transformAsString, sourceIds, partitionId, name)); + } return this; } - Builder addField(String transformAsString, int sourceId, String name) { - fields.add(new UnboundPartitionField(transformAsString, sourceId, null, name)); + Builder addField(String transformAsString, int sourceId, int[] sourceIds, String name) { + if (sourceIds.length == 1) { + fields.add(new UnboundPartitionField(transformAsString, sourceId, null, name)); + } else { + fields.add(new UnboundPartitionField(transformAsString, sourceIds, null, name)); + } return this; } @@ -106,6 +116,7 @@ UnboundPartitionSpec build() { static class UnboundPartitionField { private final Transform transform; private final int sourceId; + private final int[] sourceIds; private final Integer partitionId; private final String name; @@ -121,6 +132,10 @@ public int sourceId() { return sourceId; } + public int[] sourceIds() { + return sourceIds; + } + public Integer partitionId() { return partitionId; } @@ -133,6 +148,19 @@ private UnboundPartitionField( String transformAsString, int sourceId, Integer partitionId, String name) { this.transform = Transforms.fromString(transformAsString); this.sourceId = sourceId; + this.sourceIds = new int[] {sourceId}; + this.partitionId = partitionId; + this.name = name; + } + + private UnboundPartitionField( + String transformAsString, int[] sourceIds, Integer partitionId, String name) { + Preconditions.checkArgument( + sourceIds != null && sourceIds.length >= 1, + "At least one source id should be provided"); + this.transform = Transforms.fromString(transformAsString); + 
this.sourceId = sourceIds.length > 1 ? -1 : sourceIds[0]; + this.sourceIds = sourceIds; this.partitionId = partitionId; this.name = name; } diff --git a/api/src/main/java/org/apache/iceberg/UnboundSortOrder.java b/api/src/main/java/org/apache/iceberg/UnboundSortOrder.java index ce9f6b1d6b2c..452fdd7f9c30 100644 --- a/api/src/main/java/org/apache/iceberg/UnboundSortOrder.java +++ b/api/src/main/java/org/apache/iceberg/UnboundSortOrder.java @@ -20,6 +20,7 @@ import java.util.Collections; import java.util.List; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.transforms.Transform; import org.apache.iceberg.transforms.Transforms; @@ -98,8 +99,16 @@ Builder withOrderId(int newOrderId) { } Builder addSortField( - String transformAsString, int sourceId, SortDirection direction, NullOrder nullOrder) { - fields.add(new UnboundSortField(transformAsString, sourceId, direction, nullOrder)); + String transformAsString, + int sourceId, + int[] sourceIds, + SortDirection direction, + NullOrder nullOrder) { + if (sourceIds.length == 1) { + fields.add(new UnboundSortField(transformAsString, sourceId, direction, nullOrder)); + } else { + fields.add(new UnboundSortField(transformAsString, sourceIds, direction, nullOrder)); + } return this; } @@ -124,6 +133,7 @@ UnboundSortOrder build() { static class UnboundSortField { private final Transform transform; private final int sourceId; + private final int[] sourceIds; private final SortDirection direction; private final NullOrder nullOrder; @@ -131,6 +141,18 @@ private UnboundSortField( String transformAsString, int sourceId, SortDirection direction, NullOrder nullOrder) { this.transform = Transforms.fromString(transformAsString); this.sourceId = sourceId; + this.sourceIds = new int[] {sourceId}; + this.direction = direction; + this.nullOrder = nullOrder; + } + + private UnboundSortField( + String transformAsString, 
int[] sourceIds, SortDirection direction, NullOrder nullOrder) { + Preconditions.checkArgument( + sourceIds != null && sourceIds.length >= 1, "at least one source id should be provided"); + this.transform = Transforms.fromString(transformAsString); + this.sourceId = sourceIds.length > 1 ? -1 : sourceIds[0]; + this.sourceIds = sourceIds; this.direction = direction; this.nullOrder = nullOrder; } @@ -143,6 +165,10 @@ public int sourceId() { return sourceId; } + public int[] sourceIds() { + return sourceIds; + } + public SortDirection direction() { return direction; } diff --git a/api/src/main/java/org/apache/iceberg/expressions/Bound.java b/api/src/main/java/org/apache/iceberg/expressions/Bound.java index e2434fbf5a79..3f15d6f32528 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/Bound.java +++ b/api/src/main/java/org/apache/iceberg/expressions/Bound.java @@ -18,6 +18,8 @@ */ package org.apache.iceberg.expressions; +import java.util.Collections; +import java.util.List; import org.apache.iceberg.StructLike; /** @@ -29,6 +31,11 @@ public interface Bound { /** Returns the underlying reference. */ BoundReference ref(); + /** Returns all the underlying references */ + default List> refs() { + return Collections.singletonList(ref()); + } + /** * Produce a value from the struct for this expression. 
* diff --git a/api/src/main/java/org/apache/iceberg/expressions/BoundTransform.java b/api/src/main/java/org/apache/iceberg/expressions/BoundTransform.java index 22271aaed9d5..353c2d3515d8 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/BoundTransform.java +++ b/api/src/main/java/org/apache/iceberg/expressions/BoundTransform.java @@ -18,10 +18,18 @@ */ package org.apache.iceberg.expressions; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; import org.apache.iceberg.StructLike; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.transforms.Transform; import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; import org.apache.iceberg.util.SerializableFunction; +import org.apache.iceberg.util.StructProjection; /** * A transform expression. @@ -30,24 +38,65 @@ * @param the Java type of values returned by the function. 
*/ public class BoundTransform implements BoundTerm { - private final BoundReference ref; + private final BoundReference[] refs; + private final StructProjection projection; + private Type inputType; private final Transform transform; private final SerializableFunction func; BoundTransform(BoundReference ref, Transform transform) { - this.ref = ref; + this.refs = Collections.singletonList(ref).toArray(new BoundReference[0]); this.transform = transform; + this.projection = null; this.func = transform.bind(ref.type()); } + BoundTransform( + List> refs, StructProjection projection, Transform transform) { + Preconditions.checkArgument( + refs != null && refs.size() >= 1, "At least one reference should be provided"); + if (refs.size() == 1) { + Preconditions.checkArgument( + projection == null, "For singe arg transform, projection should be null"); + } + this.refs = refs.toArray(new BoundReference[0]); + this.transform = transform; + this.projection = projection; + this.func = transform.bind(lazyInputType()); + } + + private Type lazyInputType() { + if (inputType == null) { + if (this.refs.length == 1) { + inputType = refs[0].type(); + } else { + List fields = + Arrays.stream(refs).map(BoundReference::field).collect(Collectors.toList()); + inputType = Types.StructType.of(fields); + } + } + return inputType; + } + @Override + @SuppressWarnings("unchecked") public T eval(StructLike struct) { - return func.apply(ref.eval(struct)); + if (projection == null) { + return func.apply(ref().eval(struct)); + } else { + return func.apply((S) projection.wrap(struct)); + } } @Override + @SuppressWarnings("unchecked") public BoundReference ref() { - return ref; + return (BoundReference) refs[0]; + } + + @Override + public List> refs() { + return Arrays.asList(this.refs); } public Transform transform() { @@ -56,16 +105,30 @@ public Transform transform() { @Override public Type type() { - return transform.getResultType(ref.type()); + return transform.getResultType(lazyInputType()); + } 
+ + private boolean areRefsEquivalent(List> left, List> right) { + if (left.size() != right.size()) { + return false; + } + Iterator> leftIter = left.iterator(); + Iterator> rightIter = right.iterator(); + while (leftIter.hasNext() && rightIter.hasNext()) { + if (!leftIter.next().isEquivalentTo(rightIter.next())) { + return false; + } + } + return !leftIter.hasNext() && !rightIter.hasNext(); } @Override public boolean isEquivalentTo(BoundTerm other) { if (other instanceof BoundTransform) { BoundTransform bound = (BoundTransform) other; - return ref.isEquivalentTo(bound.ref()) && transform.equals(bound.transform()); + return areRefsEquivalent(this.refs(), other.refs()) && transform.equals(bound.transform()); } else if (transform.isIdentity() && other instanceof BoundReference) { - return ref.isEquivalentTo(other); + return refs.length == 1 && ref().isEquivalentTo(other); } return false; @@ -73,6 +136,10 @@ public boolean isEquivalentTo(BoundTerm other) { @Override public String toString() { - return transform + "(" + ref + ")"; + if (refs.length == 1) { + return transform + "(" + ref() + ")"; + } else { + return transform + "(" + Arrays.toString(refs) + ")"; + } } } diff --git a/api/src/main/java/org/apache/iceberg/expressions/Expressions.java b/api/src/main/java/org/apache/iceberg/expressions/Expressions.java index f21a7705968b..beba5ea20664 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/Expressions.java +++ b/api/src/main/java/org/apache/iceberg/expressions/Expressions.java @@ -18,6 +18,9 @@ */ package org.apache.iceberg.expressions; +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; import java.util.stream.Stream; import org.apache.iceberg.expressions.Expression.Operation; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; @@ -77,6 +80,18 @@ public static UnboundTerm bucket(String name, int numBuckets) { return new UnboundTransform<>(ref(name), transform); } + public static UnboundTerm 
bucket(int numBuckets, String... names) { + if (names.length == 1) { + return bucket(names[0], numBuckets); + } + Transform transform = (Transform) Transforms.bucket(numBuckets); + NamedReference[] references = new NamedReference[names.length]; + for (int i = 0; i < names.length; i++) { + references[i] = ref(names[i]); + } + return new UnboundTransform<>(Arrays.asList(references), transform); + } + @SuppressWarnings("unchecked") public static UnboundTerm year(String name) { return new UnboundTransform<>(ref(name), (Transform) Transforms.year()); @@ -309,6 +324,12 @@ public static UnboundTerm transform(String name, Transform transfor return new UnboundTransform<>(ref(name), transform); } + public static UnboundTerm transform(List names, Transform transform) { + List> refs = + names.stream().map(Expressions::ref).collect(Collectors.toList()); + return new UnboundTransform<>(refs, transform); + } + public static UnboundAggregate count(String name) { return new UnboundAggregate<>(Operation.COUNT, ref(name)); } diff --git a/api/src/main/java/org/apache/iceberg/expressions/Unbound.java b/api/src/main/java/org/apache/iceberg/expressions/Unbound.java index 557ac3fd26be..678d7bea3a48 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/Unbound.java +++ b/api/src/main/java/org/apache/iceberg/expressions/Unbound.java @@ -18,6 +18,8 @@ */ package org.apache.iceberg.expressions; +import java.util.Collections; +import java.util.List; import org.apache.iceberg.types.Types; /** @@ -39,4 +41,8 @@ public interface Unbound { /** Returns this expression's underlying reference. 
*/ NamedReference ref(); + + default List> refs() { + return Collections.singletonList(ref()); + } } diff --git a/api/src/main/java/org/apache/iceberg/expressions/UnboundTransform.java b/api/src/main/java/org/apache/iceberg/expressions/UnboundTransform.java index cae84733c8d5..1696c96dee60 100644 --- a/api/src/main/java/org/apache/iceberg/expressions/UnboundTransform.java +++ b/api/src/main/java/org/apache/iceberg/expressions/UnboundTransform.java @@ -18,22 +18,42 @@ */ package org.apache.iceberg.expressions; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; import org.apache.iceberg.exceptions.ValidationException; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.transforms.Transform; import org.apache.iceberg.types.Types; +import org.apache.iceberg.types.Types.StructType; +import org.apache.iceberg.util.StructProjection; public class UnboundTransform implements UnboundTerm, Term { - private final NamedReference ref; + private final NamedReference[] refs; private final Transform transform; UnboundTransform(NamedReference ref, Transform transform) { - this.ref = ref; + this.refs = Collections.singletonList(ref).toArray(new NamedReference[0]); this.transform = transform; } + UnboundTransform(List> refs, Transform transform) { + Preconditions.checkArgument( + refs != null && refs.size() >= 1, "At least one reference should be provided"); + this.refs = refs.toArray(new NamedReference[0]); + this.transform = transform; + } + + @SuppressWarnings("unchecked") @Override public NamedReference ref() { - return ref; + return (NamedReference) refs[0]; + } + + @Override + public List> refs() { + return Arrays.asList(refs); } public Transform transform() { @@ -42,26 +62,61 @@ public Transform transform() { @Override public BoundTransform bind(Types.StructType struct, boolean caseSensitive) { - BoundReference boundRef = ref.bind(struct, caseSensitive); + if 
(refs.length == 1) { + BoundReference boundRef = ref().bind(struct, caseSensitive); + return bindSingleRef(boundRef); + } else { + List> boundRefs = + Arrays.stream(refs).map(x -> x.bind(struct, caseSensitive)).collect(Collectors.toList()); + return bindRefs(struct, boundRefs); + } + } + private BoundTransform bindSingleRef(BoundReference boundRef) { try { ValidationException.check( transform.canTransform(boundRef.type()), "Cannot bind: %s cannot transform %s values from '%s'", transform, boundRef.type(), - ref.name()); + ref().name()); } catch (IllegalArgumentException e) { throw new ValidationException( "Cannot bind: %s cannot transform %s values from '%s'", - transform, boundRef.type(), ref.name()); + transform, boundRef.type(), ref().name()); } return new BoundTransform<>(boundRef, transform); } + private BoundTransform bindRefs( + Types.StructType structType, List> boundedRefs) { + StructType projectedType = + StructType.of(boundedRefs.stream().map(BoundReference::field).collect(Collectors.toList())); + StructProjection projection = StructProjection.create(structType, projectedType); + String refNames = + boundedRefs.stream().map(BoundReference::name).collect(Collectors.joining("(", ",", ")")); + try { + ValidationException.check( + transform.canTransform(projectedType), + "Cannot bind: %s cannot transform %s values from '%s'", + transform, + projectedType, + refNames); + } catch (IllegalArgumentException e) { + throw new ValidationException( + "Cannot bind: %s cannot transform %s values from '%s'", + transform, projectedType, refNames); + } + return new BoundTransform<>(boundedRefs, projection, transform); + } + @Override public String toString() { - return transform + "(" + ref + ")"; + if (refs.length == 1) { + return transform + "(" + ref() + ")"; + } else { + return transform + "(" + Arrays.toString(refs) + ")"; + } } } diff --git a/api/src/main/java/org/apache/iceberg/transforms/Bucket.java b/api/src/main/java/org/apache/iceberg/transforms/Bucket.java 
index 912bcd271725..e2c0d8b06555 100644 --- a/api/src/main/java/org/apache/iceberg/transforms/Bucket.java +++ b/api/src/main/java/org/apache/iceberg/transforms/Bucket.java @@ -21,8 +21,10 @@ import java.io.Serializable; import java.math.BigDecimal; import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.UUID; import java.util.function.Function; +import org.apache.iceberg.StructLike; import org.apache.iceberg.expressions.BoundPredicate; import org.apache.iceberg.expressions.BoundTransform; import org.apache.iceberg.expressions.Expression; @@ -30,6 +32,7 @@ import org.apache.iceberg.expressions.UnboundPredicate; import org.apache.iceberg.relocated.com.google.common.base.Objects; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.hash.Hasher; import org.apache.iceberg.types.Type; import org.apache.iceberg.types.Types; import org.apache.iceberg.util.BucketUtil; @@ -65,6 +68,8 @@ static & SerializableFunction> B get( return (B) new BucketByteBuffer(numBuckets); case UUID: return (B) new BucketUUID(numBuckets); + case STRUCT: + return (B) new BucketStructLike(numBuckets, type.asStructType()); default: throw new IllegalArgumentException("Cannot bucket by type: " + type); } @@ -101,6 +106,21 @@ public Integer apply(T value) { @Override public boolean canTransform(Type type) { + if (type.isPrimitiveType()) { + return canTransformPrimitive(type); + } else { + switch (type.typeId()) { + case STRUCT: + // only struct of primitives fields can be bucketed. 
+ return type.asStructType().fields().stream() + .map(Types.NestedField::type) + .allMatch(this::canTransform); + } + } + return false; + } + + private boolean canTransformPrimitive(Type type) { switch (type.typeId()) { case INTEGER: case LONG: @@ -265,4 +285,58 @@ protected int hash(BigDecimal value) { return BucketUtil.hash(value); } } + + private static class BucketStructLike extends Bucket + implements SerializableFunction { + + private final Types.StructType type; + + private BucketStructLike(int numBuckets, Types.StructType type) { + super(numBuckets); + this.type = type; + } + + @Override + protected int hash(StructLike value) { + Hasher hasher = BucketUtil.hasher(); + boolean isNull = true; + for (int i = 0; i < value.size(); i += 1) { + Type fieldType = type.fields().get(i).type(); + Object val = value.get(i, fieldType.typeId().javaClass()); + if (val == null) { + continue; + } else { + isNull = false; + } + switch (fieldType.typeId()) { + case INTEGER: + case DATE: + hasher.putLong((long) ((int) val)); + break; + case LONG: + case TIME: + case TIMESTAMP: + hasher.putLong((long) val); + break; + case DECIMAL: + hasher.putBytes(((BigDecimal) val).unscaledValue().toByteArray()); + break; + case STRING: + hasher.putString((CharSequence) val, StandardCharsets.UTF_8); + break; + case FIXED: + case BINARY: + hasher.putBytes((ByteBuffer) val); + break; + case UUID: + UUID uuid = (UUID) val; + hasher.putLong(Long.reverseBytes(uuid.getMostSignificantBits())); + hasher.putLong(Long.reverseBytes(uuid.getLeastSignificantBits())); + break; + } + } + Preconditions.checkArgument(!isNull, "All fields are null"); + return hasher.hash().asInt(); + } + } } diff --git a/api/src/main/java/org/apache/iceberg/transforms/PartitionSpecVisitor.java b/api/src/main/java/org/apache/iceberg/transforms/PartitionSpecVisitor.java index e4796478bf28..ffb1387264ec 100644 --- a/api/src/main/java/org/apache/iceberg/transforms/PartitionSpecVisitor.java +++ 
b/api/src/main/java/org/apache/iceberg/transforms/PartitionSpecVisitor.java @@ -18,6 +18,7 @@ */ package org.apache.iceberg.transforms; +import java.util.Arrays; import java.util.List; import org.apache.iceberg.PartitionField; import org.apache.iceberg.PartitionSpec; @@ -37,10 +38,18 @@ default T bucket(int fieldId, String sourceName, int sourceId, int numBuckets) { return bucket(sourceName, sourceId, numBuckets); } + default T bucket(int fieldId, String[] sourceNames, int[] sourceIds, int numBuckets) { + return bucket(sourceNames, sourceIds, numBuckets); + } + default T bucket(String sourceName, int sourceId, int numBuckets) { throw new UnsupportedOperationException("Bucket transform is not supported"); } + default T bucket(String[] sourceNames, int[] sourceIds, int numBuckets) { + throw new UnsupportedOperationException("Bucket transform is not supported"); + } + default T truncate(int fieldId, String sourceName, int sourceId, int width) { return truncate(sourceName, sourceId, width); } @@ -117,7 +126,15 @@ static R visit(Schema schema, PartitionField field, PartitionSpecVisitor return visitor.identity(field.fieldId(), sourceName, field.sourceId()); } else if (transform instanceof Bucket) { int numBuckets = ((Bucket) transform).numBuckets(); - return visitor.bucket(field.fieldId(), sourceName, field.sourceId(), numBuckets); + if (field.sourceIds().length == 1) { + return visitor.bucket(field.fieldId(), sourceName, field.sourceId(), numBuckets); + } else { + String[] sourceColumnNames = + Arrays.stream(field.sourceIds()) + .mapToObj(schema::findColumnName) + .toArray(String[]::new); + return visitor.bucket(field.fieldId(), sourceColumnNames, field.sourceIds(), numBuckets); + } } else if (transform instanceof Truncate) { int width = ((Truncate) transform).width(); return visitor.truncate(field.fieldId(), sourceName, field.sourceId(), width); diff --git a/api/src/main/java/org/apache/iceberg/transforms/SortOrderVisitor.java 
b/api/src/main/java/org/apache/iceberg/transforms/SortOrderVisitor.java index 680e095270fb..c371fb565be3 100644 --- a/api/src/main/java/org/apache/iceberg/transforms/SortOrderVisitor.java +++ b/api/src/main/java/org/apache/iceberg/transforms/SortOrderVisitor.java @@ -18,12 +18,14 @@ */ package org.apache.iceberg.transforms; +import java.util.Arrays; import java.util.List; import org.apache.iceberg.NullOrder; import org.apache.iceberg.Schema; import org.apache.iceberg.SortDirection; import org.apache.iceberg.SortField; import org.apache.iceberg.SortOrder; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.Lists; public interface SortOrderVisitor { @@ -33,6 +35,26 @@ public interface SortOrderVisitor { T bucket( String sourceName, int sourceId, int width, SortDirection direction, NullOrder nullOrder); + default T bucket( + String[] sourceNames, + int[] sourceIds, + int width, + SortDirection direction, + NullOrder nullOrder) { + Preconditions.checkArgument( + sourceNames != null + && sourceNames.length >= 1 + && sourceIds != null + && sourceIds.length >= 1, + "At least one sourceName and sourceId should be provided"); + if (sourceNames.length == 1 && sourceIds.length == 1) { + return bucket(sourceNames[0], sourceIds[0], width, direction, nullOrder); + } else { + throw new UnsupportedOperationException( + String.format("bucket on multiple columns is not supported")); + } + } + T truncate( String sourceName, int sourceId, int width, SortDirection direction, NullOrder nullOrder); @@ -76,9 +98,12 @@ static List visit(SortOrder sortOrder, SortOrderVisitor visitor) { visitor.field(sourceName, field.sourceId(), field.direction(), field.nullOrder())); } else if (transform instanceof Bucket) { int numBuckets = ((Bucket) transform).numBuckets(); + int[] sourceIds = field.sourceIds(); + String[] sourceNames = + Arrays.stream(sourceIds).mapToObj(schema::findColumnName).toArray(String[]::new); 
results.add( visitor.bucket( - sourceName, field.sourceId(), numBuckets, field.direction(), field.nullOrder())); + sourceNames, sourceIds, numBuckets, field.direction(), field.nullOrder())); } else if (transform instanceof Truncate) { int width = ((Truncate) transform).width(); results.add( diff --git a/api/src/main/java/org/apache/iceberg/util/BucketUtil.java b/api/src/main/java/org/apache/iceberg/util/BucketUtil.java index 1f3a68aef52b..7b77d0ea6661 100644 --- a/api/src/main/java/org/apache/iceberg/util/BucketUtil.java +++ b/api/src/main/java/org/apache/iceberg/util/BucketUtil.java @@ -23,6 +23,7 @@ import java.nio.charset.StandardCharsets; import java.util.UUID; import org.apache.iceberg.relocated.com.google.common.hash.HashFunction; +import org.apache.iceberg.relocated.com.google.common.hash.Hasher; import org.apache.iceberg.relocated.com.google.common.hash.Hashing; /** @@ -97,4 +98,8 @@ public static int hash(UUID value) { public static int hash(BigDecimal value) { return MURMUR3.hashBytes(value.unscaledValue().toByteArray()).asInt(); } + + public static Hasher hasher() { + return MURMUR3.newHasher(); + } } diff --git a/core/src/main/java/org/apache/iceberg/PartitionSpecParser.java b/core/src/main/java/org/apache/iceberg/PartitionSpecParser.java index a51b03c8f015..2831a170c49c 100644 --- a/core/src/main/java/org/apache/iceberg/PartitionSpecParser.java +++ b/core/src/main/java/org/apache/iceberg/PartitionSpecParser.java @@ -35,6 +35,7 @@ private PartitionSpecParser() {} private static final String SPEC_ID = "spec-id"; private static final String FIELDS = "fields"; private static final String SOURCE_ID = "source-id"; + private static final String SOURCE_IDS = "source-ids"; private static final String FIELD_ID = "field-id"; private static final String TRANSFORM = "transform"; private static final String NAME = "name"; @@ -99,6 +100,16 @@ static void toJsonFields(UnboundPartitionSpec spec, JsonGenerator generator) thr generator.writeStringField(NAME, 
field.name()); generator.writeStringField(TRANSFORM, field.transformAsString()); generator.writeNumberField(SOURCE_ID, field.sourceId()); + // only serialize multiple sourceIds + if (field.sourceIds().length > 1) { + // fieldName: SOURCE_IDS, array: sourceIds. + generator.writeFieldName(SOURCE_IDS); + generator.writeStartArray(); + for (int i : field.sourceIds()) { + generator.writeNumber(i); + } + generator.writeEndArray(); + } generator.writeNumberField(FIELD_ID, field.partitionId()); generator.writeEndObject(); } @@ -133,14 +144,17 @@ private static void buildFromJsonFields(UnboundPartitionSpec.Builder builder, Js String name = JsonUtil.getString(NAME, element); String transform = JsonUtil.getString(TRANSFORM, element); int sourceId = JsonUtil.getInt(SOURCE_ID, element); + int[] sourceIds = JsonUtil.getIntArrayOrNull(SOURCE_IDS, element); + // backward compatibility + sourceIds = sourceIds == null ? new int[] {sourceId} : sourceIds; // partition field ids are missing in old PartitionSpec, they always auto-increment from // PARTITION_DATA_ID_START if (element.has(FIELD_ID)) { - builder.addField(transform, sourceId, JsonUtil.getInt(FIELD_ID, element), name); + builder.addField(transform, sourceId, sourceIds, JsonUtil.getInt(FIELD_ID, element), name); fieldIdCount++; } else { - builder.addField(transform, sourceId, name); + builder.addField(transform, sourceId, sourceIds, name); } } diff --git a/core/src/main/java/org/apache/iceberg/SortOrderParser.java b/core/src/main/java/org/apache/iceberg/SortOrderParser.java index 31307cf9dc7f..db9ef90dfcea 100644 --- a/core/src/main/java/org/apache/iceberg/SortOrderParser.java +++ b/core/src/main/java/org/apache/iceberg/SortOrderParser.java @@ -36,6 +36,7 @@ public class SortOrderParser { private static final String NULL_ORDER = "null-order"; private static final String TRANSFORM = "transform"; private static final String SOURCE_ID = "source-id"; + private static final String SOURCE_IDS = "source-ids"; private 
SortOrderParser() {} @@ -70,6 +71,15 @@ private static void toJsonFields(SortOrder sortOrder, JsonGenerator generator) generator.writeStartObject(); generator.writeStringField(TRANSFORM, field.transform().toString()); generator.writeNumberField(SOURCE_ID, field.sourceId()); + if (field.sourceIds().length > 1) { + // fieldName: SOURCE_IDS, array: sourceIds. + generator.writeFieldName(SOURCE_IDS); + generator.writeStartArray(); + for (int i : field.sourceIds()) { + generator.writeNumber(i); + } + generator.writeEndArray(); + } generator.writeStringField(DIRECTION, toJson(field.direction())); generator.writeStringField(NULL_ORDER, toJson(field.nullOrder())); generator.writeEndObject(); @@ -152,6 +162,8 @@ private static void buildFromJsonFields(UnboundSortOrder.Builder builder, JsonNo String transform = JsonUtil.getString(TRANSFORM, element); int sourceId = JsonUtil.getInt(SOURCE_ID, element); + int[] sourceIds = JsonUtil.getIntArrayOrNull(SOURCE_IDS, element); + sourceIds = sourceIds == null ? 
new int[] {sourceId} : sourceIds; String directionAsString = JsonUtil.getString(DIRECTION, element); SortDirection direction = SortDirection.fromString(directionAsString); @@ -159,7 +171,7 @@ private static void buildFromJsonFields(UnboundSortOrder.Builder builder, JsonNo String nullOrderingAsString = JsonUtil.getString(NULL_ORDER, element); NullOrder nullOrder = toNullOrder(nullOrderingAsString); - builder.addSortField(transform, sourceId, direction, nullOrder); + builder.addSortField(transform, sourceId, sourceIds, direction, nullOrder); } } diff --git a/core/src/main/java/org/apache/iceberg/TableMetadata.java b/core/src/main/java/org/apache/iceberg/TableMetadata.java index 25af350d5e8e..c93ea4e3d67e 100644 --- a/core/src/main/java/org/apache/iceberg/TableMetadata.java +++ b/core/src/main/java/org/apache/iceberg/TableMetadata.java @@ -19,6 +19,7 @@ package org.apache.iceberg; import java.io.Serializable; +import java.util.Arrays; import java.util.Collection; import java.util.List; import java.util.Map; @@ -108,10 +109,22 @@ static TableMetadata newTableMetadata( PartitionSpec.Builder specBuilder = PartitionSpec.builderFor(freshSchema).withSpecId(INITIAL_SPEC_ID); for (PartitionField field : spec.fields()) { - // look up the name of the source field in the old schema to get the new schema's id - String sourceName = schema.findColumnName(field.sourceId()); - // reassign all partition fields with fresh partition field Ids to ensure consistency - specBuilder.add(freshSchema.findField(sourceName).fieldId(), field.name(), field.transform()); + if (field.sourceIds().length == 1) { + // look up the name of the source field in the old schema to get the new schema's id + String sourceName = schema.findColumnName(field.sourceId()); + // reassign all partition fields with fresh partition field Ids to ensure consistency + specBuilder.add( + freshSchema.findField(sourceName).fieldId(), field.name(), field.transform()); + } else { + int[] originalSourceIds = field.sourceIds(); 
+ int[] sourceIds = + Arrays.stream(originalSourceIds) + .mapToObj(schema::findColumnName) + .map(x -> freshSchema.findField(x).fieldId()) + .mapToInt(x -> x) + .toArray(); + specBuilder.add(sourceIds, field.name(), field.transform()); + } } PartitionSpec freshSpec = specBuilder.build(); @@ -731,18 +744,33 @@ private static PartitionSpec freshSpec(int specId, Schema schema, PartitionSpec UnboundPartitionSpec.Builder specBuilder = UnboundPartitionSpec.builder().withSpecId(specId); for (PartitionField field : partitionSpec.fields()) { - // look up the name of the source field in the old schema to get the new schema's id - String sourceName = partitionSpec.schema().findColumnName(field.sourceId()); + if (field.sourceIds().length == 1) { // single arg transform + // look up the name of the source field in the old schema to get the new schema's id + String sourceName = partitionSpec.schema().findColumnName(field.sourceId()); - final int fieldId; - if (sourceName != null) { - fieldId = schema.findField(sourceName).fieldId(); - } else { - // In the case of a null sourceName, the column has been deleted. - // This only happens in V1 tables where the reference is still around as a void transform - fieldId = field.sourceId(); + final int sourceId; + if (sourceName != null) { + sourceId = schema.findField(sourceName).fieldId(); + } else { + // In the case of a null sourceName, the column has been deleted. 
+ // This only happens in V1 tables where the reference is still around as a void transform + sourceId = field.sourceId(); + } + final int[] sourceIds = new int[] {sourceId}; + // todo: handle spec evolution + specBuilder.addField( + field.transform().toString(), sourceId, sourceIds, field.fieldId(), field.name()); + } else { // multi-args transform + int[] originalSourceIds = field.sourceIds(); + int[] sourceIds = + Arrays.stream(originalSourceIds) + .mapToObj(x -> partitionSpec.schema().findColumnName(x)) + .map(x -> schema.findField(x).fieldId()) + .mapToInt(x -> x) + .toArray(); + specBuilder.addField( + field.transform().toString(), -1, sourceIds, field.fieldId(), field.name()); } - specBuilder.addField(field.transform().toString(), fieldId, field.fieldId(), field.name()); } return specBuilder.build().bind(schema); @@ -756,12 +784,22 @@ private static SortOrder freshSortOrder(int orderId, Schema schema, SortOrder so } for (SortField field : sortOrder.fields()) { - // look up the name of the source field in the old schema to get the new schema's id - String sourceName = sortOrder.schema().findColumnName(field.sourceId()); - // reassign all sort fields with fresh sort field IDs - int newSourceId = schema.findField(sourceName).fieldId(); + int[] newSourceIds = new int[field.sourceIds().length]; + int idx = 0; + for (int sourceId : field.sourceIds()) { + // look up the name of the source field in the old schema to get the new schema's id + String sourceName = sortOrder.schema().findColumnName(sourceId); + // reassign all sort fields with fresh sort field IDs + int newSourceId = schema.findField(sourceName).fieldId(); + newSourceIds[idx++] = newSourceId; + } + int newSourceId = newSourceIds.length > 1 ? 
-1 : newSourceIds[0]; builder.addSortField( - field.transform().toString(), newSourceId, field.direction(), field.nullOrder()); + field.transform().toString(), + newSourceId, + newSourceIds, + field.direction(), + field.nullOrder()); } return builder.build().bind(schema); diff --git a/core/src/main/java/org/apache/iceberg/util/CopySortOrderFields.java b/core/src/main/java/org/apache/iceberg/util/CopySortOrderFields.java index 433f30f81386..7dd1998a4458 100644 --- a/core/src/main/java/org/apache/iceberg/util/CopySortOrderFields.java +++ b/core/src/main/java/org/apache/iceberg/util/CopySortOrderFields.java @@ -48,6 +48,17 @@ public Void bucket( return null; } + @Override + public Void bucket( + String[] sourceNames, + int[] sourceIds, + int numBuckets, + SortDirection direction, + NullOrder nullOrder) { + builder.sortBy(Expressions.bucket(numBuckets, sourceNames), direction, nullOrder); + return null; + } + @Override public Void truncate( String sourceName, int sourceId, int width, SortDirection direction, NullOrder nullOrder) { diff --git a/core/src/main/java/org/apache/iceberg/util/SortOrderUtil.java b/core/src/main/java/org/apache/iceberg/util/SortOrderUtil.java index 37e0c1fffab0..f5f1acb1b9bb 100644 --- a/core/src/main/java/org/apache/iceberg/util/SortOrderUtil.java +++ b/core/src/main/java/org/apache/iceberg/util/SortOrderUtil.java @@ -18,7 +18,9 @@ */ package org.apache.iceberg.util; +import java.util.Arrays; import java.util.Collections; +import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -96,8 +98,16 @@ public static SortOrder buildSortOrder(Schema schema, PartitionSpec spec, SortOr // build a sort prefix of partition fields that are not already in the sort order's prefix SortOrder.Builder builder = SortOrder.builderFor(schema); for (PartitionField field : requiredClusteringFields.values()) { - String sourceName = schema.findColumnName(field.sourceId()); - builder.asc(Expressions.transform(sourceName, 
field.transform())); + if (field.sourceIds().length == 1) { + String sourceName = schema.findColumnName(field.sourceId()); + builder.asc(Expressions.transform(sourceName, field.transform())); + } else { + List sourceNames = + Arrays.stream(field.sourceIds()) + .mapToObj(schema::findColumnName) + .collect(Collectors.toList()); + builder.asc(Expressions.transform(sourceNames, field.transform())); + } } // add the configured sort to the partition spec prefix sort diff --git a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestIcebergExpressions.java b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestIcebergExpressions.java index b1ba53455123..9684358878fe 100644 --- a/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestIcebergExpressions.java +++ b/spark/v3.3/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestIcebergExpressions.java @@ -102,4 +102,37 @@ public void testBucketExpressions() { ImmutableList.of(row(0, 2, 0, 4, 1)), sql("SELECT int_c, long_c, dec_c, str_c, binary_c FROM v")); } + + @Test + public void testMultiArgBucketExpressions() { + sql( + "CREATE TABLE %s ( " + + " int_c INT, long_c LONG, dec_c DECIMAL(4, 2), str_c STRING, binary_c BINARY " + + ") USING iceberg" + + " PARTITIONED BY (bucket(4, int_c, long_c))", + tableName); + sql( + "CREATE TEMPORARY VIEW emp " + + "AS SELECT * FROM VALUES (101, 10001, 10.65, '101-Employee', CAST('1234' AS BINARY)) " + + "AS EMP(int_c, long_c, dec_c, str_c, binary_c)"); + + sql("INSERT INTO %s SELECT * FROM emp", tableName); + spark.table(tableName).show(); + spark + .sql(String.format("select file_path, partition, sort_order_id from %s.files", tableName)) + .show(false); + Dataset df = spark.sql("SELECT * FROM " + tableName); + df.select( + new Column(new IcebergBucketTransform(2, df.col("int_c").expr())).as("int_c"), + new Column(new IcebergBucketTransform(3, df.col("long_c").expr())).as("long_c"), + 
new Column(new IcebergBucketTransform(4, df.col("dec_c").expr())).as("dec_c"), + new Column(new IcebergBucketTransform(5, df.col("str_c").expr())).as("str_c"), + new Column(new IcebergBucketTransform(6, df.col("binary_c").expr())).as("binary_c")) + .createOrReplaceTempView("v"); + + assertEquals( + "Should have expected rows", + ImmutableList.of(row(0, 2, 0, 4, 1)), + sql("SELECT int_c, long_c, dec_c, str_c, binary_c FROM v")); + } } diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java index 52d68db2e4f9..3b70029d6dcd 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java @@ -48,6 +48,17 @@ public SortOrder bucket( Expressions.bucket(width, quotedName(id)), toSpark(direction), toSpark(nullOrder)); } + @Override + public SortOrder bucket( + String[] sourceNames, int[] ids, int width, SortDirection direction, NullOrder nullOrder) { + String[] quotedNames = new String[ids.length]; + for (int i = 0; i < ids.length; i++) { + quotedNames[i] = quotedName(ids[i]); + } + return Expressions.sort( + Expressions.bucket(width, quotedNames), toSpark(direction), toSpark(nullOrder)); + } + @Override public SortOrder truncate( String sourceName, int id, int width, SortDirection direction, NullOrder nullOrder) { diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java index 23a53ea9e8c3..fd6f0f6bc1d8 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java @@ -96,6 +96,7 @@ public class Spark3Util { private static final Set RESERVED_PROPERTIES = ImmutableSet.of(TableCatalog.PROP_LOCATION, TableCatalog.PROP_PROVIDER); private static final Joiner 
DOT = Joiner.on("."); + private static final Set MULTI_ARGS_TRANSFORMS = ImmutableSet.of("zorder", "bucket"); private Spark3Util() {} @@ -300,6 +301,15 @@ public Transform bucket(String sourceName, int sourceId, int numBuckets) { return Expressions.bucket(numBuckets, quotedName(sourceId)); } + @Override + public Transform bucket(String[] sourceNames, int[] sourceIds, int numBuckets) { + String[] quotedNames = new String[sourceIds.length]; + for (int i = 0; i < sourceIds.length; i++) { + quotedNames[i] = quotedName(sourceIds[i]); + } + return Expressions.bucket(numBuckets, quotedNames); + } + @Override public Transform truncate(String sourceName, int sourceId, int width) { NamedReference column = Expressions.column(quotedName(sourceId)); @@ -350,7 +360,7 @@ public static Term toIcebergTerm(Expression expr) { if (expr instanceof Transform) { Transform transform = (Transform) expr; Preconditions.checkArgument( - "zorder".equals(transform.name()) || transform.references().length == 1, + MULTI_ARGS_TRANSFORMS.contains(transform.name()) || transform.references().length == 1, "Cannot convert transform with more than one column reference: %s", transform); String colName = DOT.join(transform.references()[0].fieldNames()); @@ -358,7 +368,11 @@ public static Term toIcebergTerm(Expression expr) { case "identity": return org.apache.iceberg.expressions.Expressions.ref(colName); case "bucket": - return org.apache.iceberg.expressions.Expressions.bucket(colName, findWidth(transform)); + String[] cols = + Stream.of(transform.references()) + .map(ref -> DOT.join(ref.fieldNames())) + .toArray(String[]::new); + return org.apache.iceberg.expressions.Expressions.bucket(findWidth(transform), cols); case "years": return org.apache.iceberg.expressions.Expressions.year(colName); case "months": @@ -405,7 +419,7 @@ public static PartitionSpec toPartitionSpec(Schema schema, Transform[] partition PartitionSpec.Builder builder = PartitionSpec.builderFor(schema); for (Transform transform : 
partitioning) { Preconditions.checkArgument( - transform.references().length == 1, + MULTI_ARGS_TRANSFORMS.contains(transform.name()) || transform.references().length == 1, "Cannot convert transform with more than one column reference: %s", transform); String colName = DOT.join(transform.references()[0].fieldNames()); @@ -414,7 +428,11 @@ public static PartitionSpec toPartitionSpec(Schema schema, Transform[] partition builder.identity(colName); break; case "bucket": - builder.bucket(colName, findWidth(transform)); + String[] colNames = + Arrays.stream(transform.references()) + .map(ref -> DOT.join(ref.fieldNames())) + .toArray(String[]::new); + builder.bucket(colNames, findWidth(transform)); break; case "years": builder.year(colName); @@ -961,6 +979,18 @@ public String bucket( return String.format("bucket(%s, %s) %s %s", numBuckets, sourceName, direction, nullOrder); } + @Override + public String bucket( + String[] sourceNames, + int[] sourceIds, + int numBuckets, + org.apache.iceberg.SortDirection direction, + NullOrder nullOrder) { + String sourceNameList = String.join(", ", sourceNames); + return String.format( + "bucket(%s, %s) %s %s", numBuckets, sourceNameList, direction, nullOrder); + } + @Override public String truncate( String sourceName, diff --git a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java index dd493fbc5097..e0ded1cec11d 100644 --- a/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java +++ b/spark/v3.3/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java @@ -100,7 +100,9 @@ public NamedReference[] filterAttributes() { for (PartitionSpec spec : specs()) { for (PartitionField field : spec.fields()) { - partitionFieldSourceIds.add(field.sourceId()); + for (int sourceId : field.sourceIds()) { + partitionFieldSourceIds.add(sourceId); + } } } diff --git 
a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRequiredDistributionAndOrdering.java b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRequiredDistributionAndOrdering.java index fcdf9bf992bb..8635306786b3 100644 --- a/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRequiredDistributionAndOrdering.java +++ b/spark/v3.4/spark-extensions/src/test/java/org/apache/iceberg/spark/extensions/TestRequiredDistributionAndOrdering.java @@ -32,6 +32,15 @@ import org.junit.Test; public class TestRequiredDistributionAndOrdering extends SparkExtensionsTestBase { + private static final List simpleData = + ImmutableList.of( + new ThreeColumnRecord(1, null, "A"), + new ThreeColumnRecord(2, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(3, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(4, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(5, "BBBBBBBBBB", "A"), + new ThreeColumnRecord(6, "BBBBBBBBBB", "B"), + new ThreeColumnRecord(7, "BBBBBBBBBB", "A")); public TestRequiredDistributionAndOrdering( String catalogName, String implementation, Map config) { @@ -72,6 +81,31 @@ public void testDefaultLocalSortWithBucketTransforms() throws NoSuchTableExcepti sql("SELECT count(*) FROM %s", tableName)); } + @Test + public void testDefaultLocalSortWithMultiColumnBucketTransform() throws NoSuchTableException { + + sql( + "CREATE TABLE %s (c1 INT, c2 STRING, c3 STRING) " + + "USING iceberg " + + "PARTITIONED BY (bucket(4, c1, c2))", + tableName); + + Dataset ds = spark.createDataFrame(simpleData, ThreeColumnRecord.class); + + // sort cols doesn't matter as it will be replaced by sort order inferred from bucket transform + Dataset inputDF = ds.coalesce(1).sortWithinPartitions("c3", "c2"); + + // should insert a local sort by partition columns by default + inputDF.writeTo(tableName).append(); + + assertEquals( + "Row count must match", + ImmutableList.of(row(7L)), + sql("SELECT count(*) FROM %s", tableName)); 
+ + // spark.sql(String.format("select * from %s.files", tableName)).show(false); + } + @Test public void testPartitionColumnsArePrependedForRangeDistribution() throws NoSuchTableException { sql( diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java index 781f61b33f0e..cf8465e379b5 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/SortOrderToSpark.java @@ -48,6 +48,17 @@ public SortOrder bucket( Expressions.bucket(width, quotedName(id)), toSpark(direction), toSpark(nullOrder)); } + @Override + public SortOrder bucket( + String[] sourceNames, int[] ids, int width, SortDirection direction, NullOrder nullOrder) { + String[] quotedNames = new String[ids.length]; + for (int i = 0; i < ids.length; i++) { + quotedNames[i] = quotedName(ids[i]); + } + return Expressions.sort( + Expressions.bucket(width, quotedNames), toSpark(direction), toSpark(nullOrder)); + } + @Override public SortOrder truncate( String sourceName, int id, int width, SortDirection direction, NullOrder nullOrder) { diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java index ad4e2d16b749..9643d4a21b67 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/Spark3Util.java @@ -97,6 +97,7 @@ public class Spark3Util { private static final Set RESERVED_PROPERTIES = ImmutableSet.of(TableCatalog.PROP_LOCATION, TableCatalog.PROP_PROVIDER); private static final Joiner DOT = Joiner.on("."); + private static final Set MULTI_ARGS_TRANSFORMS = ImmutableSet.of("zorder", "bucket"); private Spark3Util() {} @@ -307,6 +308,15 @@ public Transform bucket(String sourceName, int sourceId, int numBuckets) { return 
Expressions.bucket(numBuckets, quotedName(sourceId)); } + @Override + public Transform bucket(String[] sourceNames, int[] sourceIds, int numBuckets) { + String[] quotedNames = new String[sourceIds.length]; + for (int i = 0; i < sourceIds.length; i++) { + quotedNames[i] = quotedName(sourceIds[i]); + } + return Expressions.bucket(numBuckets, quotedNames); + } + @Override public Transform truncate(String sourceName, int sourceId, int width) { NamedReference column = Expressions.column(quotedName(sourceId)); @@ -357,7 +367,7 @@ public static Term toIcebergTerm(Expression expr) { if (expr instanceof Transform) { Transform transform = (Transform) expr; Preconditions.checkArgument( - "zorder".equals(transform.name()) || transform.references().length == 1, + MULTI_ARGS_TRANSFORMS.contains(transform.name()) || transform.references().length == 1, "Cannot convert transform with more than one column reference: %s", transform); String colName = DOT.join(transform.references()[0].fieldNames()); @@ -365,7 +375,11 @@ public static Term toIcebergTerm(Expression expr) { case "identity": return org.apache.iceberg.expressions.Expressions.ref(colName); case "bucket": - return org.apache.iceberg.expressions.Expressions.bucket(colName, findWidth(transform)); + String[] cols = + Stream.of(transform.references()) + .map(ref -> DOT.join(ref.fieldNames())) + .toArray(String[]::new); + return org.apache.iceberg.expressions.Expressions.bucket(findWidth(transform), cols); case "years": return org.apache.iceberg.expressions.Expressions.year(colName); case "months": @@ -412,7 +426,7 @@ public static PartitionSpec toPartitionSpec(Schema schema, Transform[] partition PartitionSpec.Builder builder = PartitionSpec.builderFor(schema); for (Transform transform : partitioning) { Preconditions.checkArgument( - transform.references().length == 1, + MULTI_ARGS_TRANSFORMS.contains(transform.name()) || transform.references().length == 1, "Cannot convert transform with more than one column reference: %s", 
transform); String colName = DOT.join(transform.references()[0].fieldNames()); @@ -421,7 +435,11 @@ public static PartitionSpec toPartitionSpec(Schema schema, Transform[] partition builder.identity(colName); break; case "bucket": - builder.bucket(colName, findWidth(transform)); + String[] colNames = + Arrays.stream(transform.references()) + .map(ref -> DOT.join(ref.fieldNames())) + .toArray(String[]::new); + builder.bucket(colNames, findWidth(transform)); break; case "years": builder.year(colName); @@ -951,6 +969,18 @@ public String bucket( return String.format("bucket(%s, %s) %s %s", numBuckets, sourceName, direction, nullOrder); } + @Override + public String bucket( + String[] sourceNames, + int[] sourceIds, + int numBuckets, + org.apache.iceberg.SortDirection direction, + NullOrder nullOrder) { + String sourceNameList = String.join(", ", sourceNames); + return String.format( + "bucket(%s, %s) %s %s", numBuckets, sourceNameList, direction, nullOrder); + } + @Override public String truncate( String sourceName, diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java index af3c67a4bb63..d122d02db239 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/functions/BucketFunction.java @@ -20,8 +20,10 @@ import java.math.BigDecimal; import java.nio.ByteBuffer; +import java.util.Arrays; import java.util.Set; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; +import org.apache.iceberg.relocated.com.google.common.hash.Hasher; import org.apache.iceberg.util.BucketUtil; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.connector.catalog.functions.BoundFunction; @@ -61,51 +63,80 @@ public class BucketFunction implements UnboundFunction { private static final Set SUPPORTED_NUM_BUCKETS_TYPES 
= ImmutableSet.of(DataTypes.ByteType, DataTypes.ShortType, DataTypes.IntegerType); + // DecimalType should be handled separately because it doesn't have a single instance. + private static final Set SUPPORTED_PRIMITIVE_VALUE_TYPES = + ImmutableSet.of( + DataTypes.DateType, + DataTypes.ByteType, + DataTypes.ShortType, + DataTypes.IntegerType, + DataTypes.LongType, + DataTypes.TimestampType, + DataTypes.TimestampNTZType, + DataTypes.StringType, + DataTypes.BinaryType); + @Override @SuppressWarnings("checkstyle:CyclomaticComplexity") public BoundFunction bind(StructType inputType) { - if (inputType.size() != 2) { + if (inputType.size() < 2) { throw new UnsupportedOperationException( - "Wrong number of inputs (expected numBuckets and value)"); + "Wrong number of inputs (expected numBuckets and values)"); } StructField numBucketsField = inputType.fields()[NUM_BUCKETS_ORDINAL]; - StructField valueField = inputType.fields()[VALUE_ORDINAL]; + StructField firstValueField = inputType.fields()[VALUE_ORDINAL]; if (!SUPPORTED_NUM_BUCKETS_TYPES.contains(numBucketsField.dataType())) { throw new UnsupportedOperationException( "Expected number of buckets to be tinyint, shortint or int"); } - DataType type = valueField.dataType(); - if (type instanceof DateType) { - return new BucketInt(type); - } else if (type instanceof ByteType - || type instanceof ShortType - || type instanceof IntegerType) { - return new BucketInt(DataTypes.IntegerType); - } else if (type instanceof LongType) { - return new BucketLong(type); - } else if (type instanceof TimestampType) { - return new BucketLong(type); - } else if (type instanceof TimestampNTZType) { - return new BucketLong(type); - } else if (type instanceof DecimalType) { - return new BucketDecimal(type); - } else if (type instanceof StringType) { - return new BucketString(); - } else if (type instanceof BinaryType) { - return new BucketBinary(); + DataType type = firstValueField.dataType(); + if (inputType.size() == 2) { + if (type instanceof 
DateType) { + return new BucketInt(type); + } else if (type instanceof ByteType + || type instanceof ShortType + || type instanceof IntegerType) { + return new BucketInt(DataTypes.IntegerType); + } else if (type instanceof LongType) { + return new BucketLong(type); + } else if (type instanceof TimestampType) { + return new BucketLong(type); + } else if (type instanceof TimestampNTZType) { + return new BucketLong(type); + } else if (type instanceof DecimalType) { + return new BucketDecimal(type); + } else if (type instanceof StringType) { + return new BucketString(); + } else if (type instanceof BinaryType) { + return new BucketBinary(); + } else { + throw new UnsupportedOperationException( + "Expected column to be date, tinyint, smallint, int, bigint, decimal, timestamp, string, or binary"); + } } else { - throw new UnsupportedOperationException( - "Expected column to be date, tinyint, smallint, int, bigint, decimal, timestamp, string, or binary"); + StructField[] tailInputFields = Arrays.copyOfRange(inputType.fields(), 1, inputType.size()); + boolean isAllPrimitive = + Arrays.stream(tailInputFields) + .map(StructField::dataType) + .allMatch( + x -> SUPPORTED_PRIMITIVE_VALUE_TYPES.contains(x) || (x instanceof DecimalType)); + if (!isAllPrimitive) { + throw new UnsupportedOperationException( + "Expected all columns to be date, tinyint, smallint, int, bigint, decimal, timestamp, string, or binary"); + } + DataType[] sqlTypes = + Arrays.stream(tailInputFields).map(StructField::dataType).toArray(DataType[]::new); + return new BucketMultiple(sqlTypes); } } @Override public String description() { return name() - + "(numBuckets, col) - Call Iceberg's bucket transform\n" + + "(numBuckets, col...) - Call Iceberg's bucket transform\n" + " numBuckets :: number of buckets to divide the rows into, e.g. 
bucket(100, 34) -> 79 (must be a tinyint, smallint, or int)\n" + " col :: column to bucket (must be a date, integer, long, timestamp, decimal, string, or binary)"; } @@ -324,4 +355,78 @@ public String canonicalName() { return "iceberg.bucket(decimal)"; } } + + public static class BucketMultiple extends BucketBase { + private final DataType[] inputTypes; + + // no magic method: `invoke` here. Because the input types are dynamic, we cannot pre generate + // the `invoke` method that matches the java type. One possible way to solve this would be + // marking this method `Unevaluable` and replace it with an `IcebergBucketExpression` in the + // extended resolution rule. We can implement `IcebergBucketExpression` with codegen support. + + public BucketMultiple(DataType[] sqlTypes) { + this.inputTypes = new DataType[sqlTypes.length + 1]; + this.inputTypes[0] = DataTypes.IntegerType; + // copies the sqlTypes into the inputTypes array starting at index 1 + System.arraycopy(sqlTypes, 0, this.inputTypes, 1, sqlTypes.length); + } + + @Override + public DataType[] inputTypes() { + return inputTypes; + } + + @SuppressWarnings("CyclomaticComplexity") + @Override + public Integer produceResult(InternalRow input) { + if (input.isNullAt(NUM_BUCKETS_ORDINAL)) { + return null; + } else { + int numBuckets = input.getInt(NUM_BUCKETS_ORDINAL); + Hasher hasher = BucketUtil.hasher(); + boolean isNull = true; + for (int i = 1; i < inputTypes.length; i++) { + if (!input.isNullAt(i)) { + isNull = false; + // switch type to get the object + if (inputTypes[i].equals(DataTypes.ByteType) + || inputTypes[i].equals(DataTypes.ShortType) + || inputTypes[i].equals(DataTypes.IntegerType) + || inputTypes[i].equals(DataTypes.DateType)) { + // Byte, Short and Integer are upcasted to int in Spark, Date are treated as int + hasher.putLong(input.getInt(i)); + } else if (inputTypes[i].equals(DataTypes.LongType)) { + hasher.putLong(input.getLong(i)); + } else if (inputTypes[i].equals(DataTypes.StringType)) { + 
// we can hash UTF8String's bytes directly since it should already be UTF-8 encoded. + hasher.putBytes(input.getUTF8String(i).getBytes()); + } else if (inputTypes[i].equals(DataTypes.BinaryType)) { + // we can hash BinaryType's bytes directly, there's no need to wrapped it in a + // ByteBuffer + hasher.putBytes(input.getBinary(i)); + } else if (inputTypes[i].equals(DataTypes.TimestampType) + || inputTypes[i].equals(DataTypes.TimestampNTZType)) { + hasher.putLong(input.getLong(i)); + } else if (inputTypes[i] instanceof DecimalType) { + DecimalType decimalType = (DecimalType) inputTypes[i]; + int precision = decimalType.precision(); + int scale = decimalType.scale(); + Decimal value = input.getDecimal(VALUE_ORDINAL, precision, scale); + hasher.putBytes(value.toJavaBigDecimal().unscaledValue().toByteArray()); + } else { + throw new UnsupportedOperationException( + "Unsupported type for bucketing: " + inputTypes[i].typeName()); + } + } + } + return isNull ? null : apply(numBuckets, hasher.hash().asInt()); + } + } + + @Override + public String canonicalName() { + // todo: specify the exact type names later + return "iceberg.bucket(multiple)"; + } + } } diff --git a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java index 66cda5b82955..7f112f2db647 100644 --- a/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java +++ b/spark/v3.4/spark/src/main/java/org/apache/iceberg/spark/source/SparkBatchQueryScan.java @@ -102,7 +102,9 @@ public NamedReference[] filterAttributes() { for (PartitionSpec spec : specs()) { for (PartitionField field : spec.fields()) { - partitionFieldSourceIds.add(field.sourceId()); + for (int sourceId : field.sourceIds()) { + partitionFieldSourceIds.add(sourceId); + } } } diff --git a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkBucketFunction.java 
b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkBucketFunction.java index c4113408aff9..6f3ebaeb5751 100644 --- a/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkBucketFunction.java +++ b/spark/v3.4/spark/src/test/java/org/apache/iceberg/spark/sql/TestSparkBucketFunction.java @@ -200,17 +200,12 @@ public void testWrongNumberOfArguments() { Assertions.assertThatThrownBy(() -> scalarSql("SELECT system.bucket()")) .isInstanceOf(AnalysisException.class) .hasMessageStartingWith( - "Function 'bucket' cannot process input: (): Wrong number of inputs (expected numBuckets and value)"); + "Function 'bucket' cannot process input: (): Wrong number of inputs (expected numBuckets and values)"); Assertions.assertThatThrownBy(() -> scalarSql("SELECT system.bucket(1)")) .isInstanceOf(AnalysisException.class) .hasMessageStartingWith( - "Function 'bucket' cannot process input: (int): Wrong number of inputs (expected numBuckets and value)"); - - Assertions.assertThatThrownBy(() -> scalarSql("SELECT system.bucket(1, 1L, 1)")) - .isInstanceOf(AnalysisException.class) - .hasMessageStartingWith( - "Function 'bucket' cannot process input: (int, bigint, int): Wrong number of inputs (expected numBuckets and value)"); + "Function 'bucket' cannot process input: (int): Wrong number of inputs (expected numBuckets and values)"); } @Test