diff --git a/data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriter.java b/data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriter.java index 29426bc97566..4e0cb7793537 100644 --- a/data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriter.java +++ b/data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriter.java @@ -32,18 +32,16 @@ import org.apache.iceberg.types.Type; import org.apache.iceberg.types.Types; import org.apache.orc.TypeDescription; -import org.apache.orc.storage.ql.exec.vector.ColumnVector; -import org.apache.orc.storage.ql.exec.vector.StructColumnVector; import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; public class GenericOrcWriter implements OrcRowWriter { - private final OrcValueWriter writer; + private final RecordWriter writer; private GenericOrcWriter(Schema expectedSchema, TypeDescription orcSchema) { Preconditions.checkArgument(orcSchema.getCategory() == TypeDescription.Category.STRUCT, "Top level must be a struct " + orcSchema); - writer = OrcSchemaWithTypeVisitor.visit(expectedSchema, orcSchema, new WriteBuilder()); + writer = (RecordWriter) OrcSchemaWithTypeVisitor.visit(expectedSchema, orcSchema, new WriteBuilder()); } public static OrcRowWriter buildWriter(Schema expectedSchema, TypeDescription fileSchema) { @@ -115,17 +113,9 @@ public OrcValueWriter primitive(Type.PrimitiveType iPrimitive, TypeDescriptio } @Override - @SuppressWarnings("unchecked") public void write(Record value, VectorizedRowBatch output) { - Preconditions.checkArgument(writer instanceof RecordWriter, "writer must be a RecordWriter."); - - int row = output.size; - output.size += 1; - List> writers = ((RecordWriter) writer).writers(); - for (int c = 0; c < writers.size(); ++c) { - OrcValueWriter child = writers.get(c); - child.write(row, value.get(c, child.getJavaClass()), output.cols[c]); - } + Preconditions.checkArgument(value != null, "value must not be null"); + writer.writeRow(value, output); } @Override @@ -133,35 +123,15 @@ public Stream> metrics() { return writer.metrics(); } - private static class RecordWriter implements OrcValueWriter { - private final List> writers; + private static class RecordWriter extends GenericOrcWriters.StructWriter { RecordWriter(List> writers) { - this.writers = writers; - } - - List> writers() { - return writers; - } - - @Override - public Class getJavaClass() { - return Record.class; - } - - @Override - @SuppressWarnings("unchecked") - public void nonNullWrite(int rowId, Record data, ColumnVector output) { - StructColumnVector cv = (StructColumnVector) output; - for (int c = 0; c < writers.size(); ++c) { - OrcValueWriter child = writers.get(c); - child.write(rowId, data.get(c, child.getJavaClass()), cv.fields[c]); - } + super(writers); } @Override - public Stream> metrics() { - return writers.stream().flatMap(OrcValueWriter::metrics); + protected Object get(Record struct, int index) { + return struct.get(index); } } } diff --git a/data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriters.java b/data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriters.java index 7efa1613de97..e0d2c5aab90b 100644 --- a/data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriters.java +++ b/data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriters.java @@ -32,6 +32,7 @@ import java.util.List; import java.util.Map; import java.util.UUID; +import java.util.function.Function; import java.util.stream.Stream; import org.apache.iceberg.DoubleFieldMetrics; import org.apache.iceberg.FieldMetrics; @@ -48,7 +49,9 @@ 
import org.apache.orc.storage.ql.exec.vector.ListColumnVector; import org.apache.orc.storage.ql.exec.vector.LongColumnVector; import org.apache.orc.storage.ql.exec.vector.MapColumnVector; +import org.apache.orc.storage.ql.exec.vector.StructColumnVector; import org.apache.orc.storage.ql.exec.vector.TimestampColumnVector; +import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; public class GenericOrcWriters { private static final OffsetDateTime EPOCH = Instant.ofEpochSecond(0).atOffset(ZoneOffset.UTC); @@ -138,11 +141,6 @@ public static OrcValueWriter> map(OrcValueWriter key, OrcVal private static class BooleanWriter implements OrcValueWriter { private static final OrcValueWriter INSTANCE = new BooleanWriter(); - @Override - public Class getJavaClass() { - return Boolean.class; - } - @Override public void nonNullWrite(int rowId, Boolean data, ColumnVector output) { ((LongColumnVector) output).vector[rowId] = data ? 1 : 0; @@ -152,11 +150,6 @@ public void nonNullWrite(int rowId, Boolean data, ColumnVector output) { private static class ByteWriter implements OrcValueWriter { private static final OrcValueWriter INSTANCE = new ByteWriter(); - @Override - public Class getJavaClass() { - return Byte.class; - } - @Override public void nonNullWrite(int rowId, Byte data, ColumnVector output) { ((LongColumnVector) output).vector[rowId] = data; @@ -166,11 +159,6 @@ public void nonNullWrite(int rowId, Byte data, ColumnVector output) { private static class ShortWriter implements OrcValueWriter { private static final OrcValueWriter INSTANCE = new ShortWriter(); - @Override - public Class getJavaClass() { - return Short.class; - } - @Override public void nonNullWrite(int rowId, Short data, ColumnVector output) { ((LongColumnVector) output).vector[rowId] = data; @@ -180,11 +168,6 @@ public void nonNullWrite(int rowId, Short data, ColumnVector output) { private static class IntWriter implements OrcValueWriter { private static final OrcValueWriter INSTANCE = new IntWriter(); - @Override - public Class getJavaClass() { - return Integer.class; - } - @Override public void nonNullWrite(int rowId, Integer data, ColumnVector output) { ((LongColumnVector) output).vector[rowId] = data; @@ -194,11 +177,6 @@ public void nonNullWrite(int rowId, Integer data, ColumnVector output) { private static class TimeWriter implements OrcValueWriter { private static final OrcValueWriter INSTANCE = new TimeWriter(); - @Override - public Class getJavaClass() { - return LocalTime.class; - } - @Override public void nonNullWrite(int rowId, LocalTime data, ColumnVector output) { ((LongColumnVector) output).vector[rowId] = data.toNanoOfDay() / 1_000; @@ -208,11 +186,6 @@ public void nonNullWrite(int rowId, LocalTime data, ColumnVector output) { private static class LongWriter implements OrcValueWriter { private static final OrcValueWriter INSTANCE = new LongWriter(); - @Override - public Class getJavaClass() { - return Long.class; - } - @Override public void nonNullWrite(int rowId, Long data, ColumnVector output) { ((LongColumnVector) output).vector[rowId] = data; @@ -227,11 +200,6 @@ private FloatWriter(int id) { this.floatFieldMetricsBuilder = new FloatFieldMetrics.Builder(id); } - @Override - public Class getJavaClass() { - return Float.class; - } - @Override public void nonNullWrite(int rowId, Float data, ColumnVector output) { ((DoubleColumnVector) output).vector[rowId] = data; @@ -261,11 +229,6 @@ private DoubleWriter(Integer id) { this.doubleFieldMetricsBuilder = new DoubleFieldMetrics.Builder(id); } - @Override - public 
Class getJavaClass() { - return Double.class; - } - @Override public void nonNullWrite(int rowId, Double data, ColumnVector output) { ((DoubleColumnVector) output).vector[rowId] = data; @@ -290,11 +253,6 @@ public Stream> metrics() { private static class StringWriter implements OrcValueWriter { private static final OrcValueWriter INSTANCE = new StringWriter(); - @Override - public Class getJavaClass() { - return String.class; - } - @Override public void nonNullWrite(int rowId, String data, ColumnVector output) { byte[] value = data.getBytes(StandardCharsets.UTF_8); @@ -305,11 +263,6 @@ public void nonNullWrite(int rowId, String data, ColumnVector output) { private static class ByteBufferWriter implements OrcValueWriter { private static final OrcValueWriter INSTANCE = new ByteBufferWriter(); - @Override - public Class getJavaClass() { - return ByteBuffer.class; - } - @Override public void nonNullWrite(int rowId, ByteBuffer data, ColumnVector output) { if (data.hasArray()) { @@ -325,11 +278,6 @@ public void nonNullWrite(int rowId, ByteBuffer data, ColumnVector output) { private static class UUIDWriter implements OrcValueWriter { private static final OrcValueWriter INSTANCE = new UUIDWriter(); - @Override - public Class getJavaClass() { - return UUID.class; - } - @Override @SuppressWarnings("ByteBufferBackingArray") public void nonNullWrite(int rowId, UUID data, ColumnVector output) { @@ -343,11 +291,6 @@ public void nonNullWrite(int rowId, UUID data, ColumnVector output) { private static class ByteArrayWriter implements OrcValueWriter { private static final OrcValueWriter INSTANCE = new ByteArrayWriter(); - @Override - public Class getJavaClass() { - return byte[].class; - } - @Override public void nonNullWrite(int rowId, byte[] data, ColumnVector output) { ((BytesColumnVector) output).setRef(rowId, data, 0, data.length); @@ -357,11 +300,6 @@ public void nonNullWrite(int rowId, byte[] data, ColumnVector output) { private static class DateWriter implements OrcValueWriter { private static final OrcValueWriter INSTANCE = new DateWriter(); - @Override - public Class getJavaClass() { - return LocalDate.class; - } - @Override public void nonNullWrite(int rowId, LocalDate data, ColumnVector output) { ((LongColumnVector) output).vector[rowId] = ChronoUnit.DAYS.between(EPOCH_DAY, data); @@ -371,11 +309,6 @@ public void nonNullWrite(int rowId, LocalDate data, ColumnVector output) { private static class TimestampTzWriter implements OrcValueWriter { private static final OrcValueWriter INSTANCE = new TimestampTzWriter(); - @Override - public Class getJavaClass() { - return OffsetDateTime.class; - } - @Override public void nonNullWrite(int rowId, OffsetDateTime data, ColumnVector output) { TimestampColumnVector cv = (TimestampColumnVector) output; @@ -389,11 +322,6 @@ public void nonNullWrite(int rowId, OffsetDateTime data, ColumnVector output) { private static class TimestampWriter implements OrcValueWriter { private static final OrcValueWriter INSTANCE = new TimestampWriter(); - @Override - public Class getJavaClass() { - return LocalDateTime.class; - } - @Override public void nonNullWrite(int rowId, LocalDateTime data, ColumnVector output) { TimestampColumnVector cv = (TimestampColumnVector) output; @@ -412,11 +340,6 @@ private static class Decimal18Writer implements OrcValueWriter { this.scale = scale; } - @Override - public Class getJavaClass() { - return BigDecimal.class; - } - @Override public void nonNullWrite(int rowId, BigDecimal data, ColumnVector output) { 
Preconditions.checkArgument(data.scale() == scale, @@ -438,11 +361,6 @@ private static class Decimal38Writer implements OrcValueWriter { this.scale = scale; } - @Override - public Class getJavaClass() { - return BigDecimal.class; - } - @Override public void nonNullWrite(int rowId, BigDecimal data, ColumnVector output) { Preconditions.checkArgument(data.scale() == scale, @@ -461,11 +379,6 @@ private static class ListWriter implements OrcValueWriter> { this.element = element; } - @Override - public Class getJavaClass() { - return List.class; - } - @Override public void nonNullWrite(int rowId, List value, ColumnVector output) { ListColumnVector cv = (ListColumnVector) output; @@ -496,11 +409,6 @@ private static class MapWriter implements OrcValueWriter> { this.valueWriter = valueWriter; } - @Override - public Class getJavaClass() { - return Map.class; - } - @Override public void nonNullWrite(int rowId, Map map, ColumnVector output) { List keys = Lists.newArrayListWithExpectedSize(map.size()); @@ -531,6 +439,46 @@ public Stream> metrics() { } } + public abstract static class StructWriter implements OrcValueWriter { + private final List> writers; + + protected StructWriter(List> writers) { + this.writers = writers; + } + + public List> writers() { + return writers; + } + + @Override + public Stream> metrics() { + return writers.stream().flatMap(OrcValueWriter::metrics); + } + + @Override + public void nonNullWrite(int rowId, S value, ColumnVector output) { + StructColumnVector cv = (StructColumnVector) output; + write(rowId, value, c -> cv.fields[c]); + } + + // Special case of writing the root struct + public void writeRow(S value, VectorizedRowBatch output) { + int rowId = output.size; + output.size += 1; + write(rowId, value, c -> output.cols[c]); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private void write(int rowId, S value, Function colVectorAtFunc) { + for (int c = 0; c < writers.size(); ++c) { + OrcValueWriter writer = writers.get(c); + writer.write(rowId, get(value, c), colVectorAtFunc.apply(c)); + } + } + + protected abstract Object get(S struct, int index); + } + private static void growColumnVector(ColumnVector cv, int requestedSize) { if (cv.isNull.length < requestedSize) { // Use growth factor of 3 to avoid frequent array allocations diff --git a/flink/src/main/java/org/apache/iceberg/flink/data/FlinkOrcWriter.java b/flink/src/main/java/org/apache/iceberg/flink/data/FlinkOrcWriter.java index 81f9822815b6..3f469b755f6b 100644 --- a/flink/src/main/java/org/apache/iceberg/flink/data/FlinkOrcWriter.java +++ b/flink/src/main/java/org/apache/iceberg/flink/data/FlinkOrcWriter.java @@ -37,17 +37,10 @@ import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; public class FlinkOrcWriter implements OrcRowWriter { - private final FlinkOrcWriters.StructWriter writer; - private final List fieldGetters; + private final FlinkOrcWriters.RowDataWriter writer; private FlinkOrcWriter(RowType rowType, Schema iSchema) { - this.writer = (FlinkOrcWriters.StructWriter) FlinkSchemaVisitor.visit(rowType, iSchema, new WriteBuilder()); - - List fieldTypes = rowType.getChildren(); - this.fieldGetters = Lists.newArrayListWithExpectedSize(fieldTypes.size()); - for (int i = 0; i < fieldTypes.size(); i++) { - fieldGetters.add(RowData.createFieldGetter(fieldTypes.get(i), i)); - } + this.writer = (FlinkOrcWriters.RowDataWriter) FlinkSchemaVisitor.visit(rowType, iSchema, new WriteBuilder()); } public static OrcRowWriter buildWriter(RowType rowType, Schema iSchema) { @@ -55,16 +48,9 @@ public 
static OrcRowWriter buildWriter(RowType rowType, Schema iSchema) } @Override - @SuppressWarnings("unchecked") public void write(RowData row, VectorizedRowBatch output) { - int rowId = output.size; - output.size += 1; - - List> writers = writer.writers(); - for (int c = 0; c < writers.size(); ++c) { - OrcValueWriter child = writers.get(c); - child.write(rowId, fieldGetters.get(c).getFieldOrNull(row), output.cols[c]); - } + Preconditions.checkArgument(row != null, "value must not be null"); + writer.writeRow(row, output); } @Override diff --git a/flink/src/main/java/org/apache/iceberg/flink/data/FlinkOrcWriters.java b/flink/src/main/java/org/apache/iceberg/flink/data/FlinkOrcWriters.java index 38a348995f00..6b596ac2063c 100644 --- a/flink/src/main/java/org/apache/iceberg/flink/data/FlinkOrcWriters.java +++ b/flink/src/main/java/org/apache/iceberg/flink/data/FlinkOrcWriters.java @@ -32,6 +32,7 @@ import org.apache.flink.table.data.TimestampData; import org.apache.flink.table.types.logical.LogicalType; import org.apache.iceberg.FieldMetrics; +import org.apache.iceberg.data.orc.GenericOrcWriters; import org.apache.iceberg.orc.OrcValueWriter; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.Lists; @@ -42,7 +43,6 @@ import org.apache.orc.storage.ql.exec.vector.ListColumnVector; import org.apache.orc.storage.ql.exec.vector.LongColumnVector; import org.apache.orc.storage.ql.exec.vector.MapColumnVector; -import org.apache.orc.storage.ql.exec.vector.StructColumnVector; import org.apache.orc.storage.ql.exec.vector.TimestampColumnVector; class FlinkOrcWriters { @@ -90,17 +90,12 @@ static OrcValueWriter map(OrcValueWriter keyWriter, OrcValueW } static OrcValueWriter struct(List> writers, List types) { - return new StructWriter(writers, types); + return new RowDataWriter(writers, types); } private static class StringWriter implements OrcValueWriter { private static final StringWriter INSTANCE = new StringWriter(); - @Override - public Class getJavaClass() { - return StringData.class; - } - @Override public void nonNullWrite(int rowId, StringData data, ColumnVector output) { byte[] value = data.toBytes(); @@ -111,11 +106,6 @@ public void nonNullWrite(int rowId, StringData data, ColumnVector output) { private static class DateWriter implements OrcValueWriter { private static final DateWriter INSTANCE = new DateWriter(); - @Override - public Class getJavaClass() { - return Integer.class; - } - @Override public void nonNullWrite(int rowId, Integer data, ColumnVector output) { ((LongColumnVector) output).vector[rowId] = data; @@ -125,11 +115,6 @@ public void nonNullWrite(int rowId, Integer data, ColumnVector output) { private static class TimeWriter implements OrcValueWriter { private static final TimeWriter INSTANCE = new TimeWriter(); - @Override - public Class getJavaClass() { - return Integer.class; - } - @Override public void nonNullWrite(int rowId, Integer millis, ColumnVector output) { // The time in flink is in millisecond, while the standard time in iceberg is microsecond. 
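Note on the `TimeWriter` hunk above: Flink hands TIME values to the writer as `Integer` milliseconds-of-day, while Iceberg's time type (and the ORC long column it lands in) is microseconds. The line that performs the conversion sits outside this hunk, so the following is only a sketch of what the retained comment implies, not code from the patch:

```java
// Assumed millis-of-day -> micros-of-day conversion for Flink TIME values.
// Example: 10:30:00.500 is 37_800_500 ms-of-day and would be stored as 37_800_500_000 µs.
static long flinkTimeMillisToIcebergMicros(int millisOfDay) {
  return millisOfDay * 1_000L;
}
```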
@@ -141,11 +126,6 @@ public void nonNullWrite(int rowId, Integer millis, ColumnVector output) { private static class TimestampWriter implements OrcValueWriter { private static final TimestampWriter INSTANCE = new TimestampWriter(); - @Override - public Class getJavaClass() { - return TimestampData.class; - } - @Override public void nonNullWrite(int rowId, TimestampData data, ColumnVector output) { TimestampColumnVector cv = (TimestampColumnVector) output; @@ -161,11 +141,6 @@ public void nonNullWrite(int rowId, TimestampData data, ColumnVector output) { private static class TimestampTzWriter implements OrcValueWriter { private static final TimestampTzWriter INSTANCE = new TimestampTzWriter(); - @Override - public Class getJavaClass() { - return TimestampData.class; - } - @Override public void nonNullWrite(int rowId, TimestampData data, ColumnVector output) { TimestampColumnVector cv = (TimestampColumnVector) output; @@ -186,11 +161,6 @@ private static class Decimal18Writer implements OrcValueWriter { this.scale = scale; } - @Override - public Class getJavaClass() { - return DecimalData.class; - } - @Override public void nonNullWrite(int rowId, DecimalData data, ColumnVector output) { Preconditions.checkArgument(scale == data.scale(), @@ -211,11 +181,6 @@ private static class Decimal38Writer implements OrcValueWriter { this.scale = scale; } - @Override - public Class getJavaClass() { - return DecimalData.class; - } - @Override public void nonNullWrite(int rowId, DecimalData data, ColumnVector output) { Preconditions.checkArgument(scale == data.scale(), @@ -236,11 +201,6 @@ static class ListWriter implements OrcValueWriter { this.elementGetter = ArrayData.createElementGetter(elementType); } - @Override - public Class getJavaClass() { - return ArrayData.class; - } - @Override @SuppressWarnings("unchecked") public void nonNullWrite(int rowId, ArrayData data, ColumnVector output) { @@ -278,11 +238,6 @@ static class MapWriter implements OrcValueWriter { this.valueGetter = ArrayData.createElementGetter(valueType); } - @Override - public Class getJavaClass() { - return MapData.class; - } - @Override @SuppressWarnings("unchecked") public void nonNullWrite(int rowId, MapData data, ColumnVector output) { @@ -311,12 +266,11 @@ public Stream> metrics() { } } - static class StructWriter implements OrcValueWriter { - private final List> writers; + static class RowDataWriter extends GenericOrcWriters.StructWriter { private final List fieldGetters; - StructWriter(List> writers, List types) { - this.writers = writers; + RowDataWriter(List> writers, List types) { + super(writers); this.fieldGetters = Lists.newArrayListWithExpectedSize(types.size()); for (int i = 0; i < types.size(); i++) { @@ -324,29 +278,11 @@ static class StructWriter implements OrcValueWriter { } } - List> writers() { - return writers; - } - @Override - public Class getJavaClass() { - return RowData.class; + protected Object get(RowData struct, int index) { + return fieldGetters.get(index).getFieldOrNull(struct); } - @Override - @SuppressWarnings("unchecked") - public void nonNullWrite(int rowId, RowData data, ColumnVector output) { - StructColumnVector cv = (StructColumnVector) output; - for (int c = 0; c < writers.size(); ++c) { - OrcValueWriter writer = writers.get(c); - writer.write(rowId, fieldGetters.get(c).getFieldOrNull(data), cv.fields[c]); - } - } - - @Override - public Stream> metrics() { - return writers.stream().flatMap(OrcValueWriter::metrics); - } } private static void growColumnVector(ColumnVector cv, int requestedSize) { 
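With the Flink changes above, both `GenericOrcWriter.RecordWriter` and `FlinkOrcWriters.RowDataWriter` now supply only a field accessor; null handling, the `StructColumnVector` fan-out, metrics aggregation, and the root-row `writeRow` path all live in the shared `GenericOrcWriters.StructWriter` base class. A minimal sketch of what a third implementation would look like, assuming the base class is generic in the struct type (the angle-bracket type parameters are stripped in the patch text above) and using Iceberg's `Pair` purely as a hypothetical carrier type:

```java
import java.util.List;

import org.apache.iceberg.data.orc.GenericOrcWriters;
import org.apache.iceberg.orc.OrcValueWriter;
import org.apache.iceberg.util.Pair;

// Hypothetical writer for a two-field struct held in a Pair; only the StructWriter
// contract (constructor, get(), writeRow()) is taken from the patch above.
class PairStructWriter extends GenericOrcWriters.StructWriter<Pair<Long, String>> {

  PairStructWriter(List<OrcValueWriter<?>> writers) {
    super(writers); // one child writer per struct field, in ORC column order
  }

  @Override
  protected Object get(Pair<Long, String> struct, int index) {
    // The only format-specific piece: pull field `index` out of the in-memory struct.
    return index == 0 ? struct.first() : struct.second();
  }
}
```

At the top level the row writers call `writeRow(value, batch)`, which claims the next row id from `VectorizedRowBatch.size` and writes into `batch.cols`; nested structs go through the inherited `nonNullWrite`, which runs the same per-field loop against `StructColumnVector.fields`.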
diff --git a/orc/src/main/java/org/apache/iceberg/orc/OrcValueWriter.java b/orc/src/main/java/org/apache/iceberg/orc/OrcValueWriter.java index b6030abb7a78..d8c27ac30879 100644 --- a/orc/src/main/java/org/apache/iceberg/orc/OrcValueWriter.java +++ b/orc/src/main/java/org/apache/iceberg/orc/OrcValueWriter.java @@ -25,8 +25,6 @@ public interface OrcValueWriter { - Class getJavaClass(); - /** * Take a value from the data value and add it to the ORC output. * diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriter.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriter.java deleted file mode 100644 index b4124468687f..000000000000 --- a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriter.java +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -package org.apache.iceberg.spark.data; - -import java.util.stream.Stream; -import org.apache.iceberg.FieldMetrics; -import org.apache.orc.storage.ql.exec.vector.ColumnVector; -import org.apache.spark.sql.catalyst.expressions.SpecializedGetters; - -interface SparkOrcValueWriter { - - /** - * Take a value from the data and add it to the ORC output. - * - * @param rowId the row id in the ColumnVector. - * @param column the column number. - * @param data the data value to write. - * @param output the ColumnVector to put the value into. - */ - default void write(int rowId, int column, SpecializedGetters data, ColumnVector output) { - if (data.isNullAt(column)) { - output.noNulls = false; - output.isNull[rowId] = true; - } else { - output.isNull[rowId] = false; - nonNullWrite(rowId, column, data, output); - } - } - - void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output); - - /** - * Returns a stream of {@link FieldMetrics} that this SparkOrcValueWriter keeps track of. - *
 * <p>
- * Since ORC keeps track of most metrics via column statistics, for now SparkOrcValueWriter only keeps track of NaN - * counters, and only return non-empty stream if the writer writes double or float values either by itself or - * transitively. - */ - default Stream> metrics() { - return Stream.empty(); - } -} diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java index df1b079bc7fa..abb12dffc050 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java +++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueWriters.java @@ -19,245 +19,115 @@ package org.apache.iceberg.spark.data; +import java.util.List; import java.util.stream.Stream; -import org.apache.iceberg.DoubleFieldMetrics; import org.apache.iceberg.FieldMetrics; -import org.apache.iceberg.FloatFieldMetrics; +import org.apache.iceberg.orc.OrcValueWriter; +import org.apache.orc.TypeDescription; import org.apache.orc.storage.common.type.HiveDecimal; import org.apache.orc.storage.ql.exec.vector.BytesColumnVector; import org.apache.orc.storage.ql.exec.vector.ColumnVector; import org.apache.orc.storage.ql.exec.vector.DecimalColumnVector; -import org.apache.orc.storage.ql.exec.vector.DoubleColumnVector; import org.apache.orc.storage.ql.exec.vector.ListColumnVector; -import org.apache.orc.storage.ql.exec.vector.LongColumnVector; import org.apache.orc.storage.ql.exec.vector.MapColumnVector; import org.apache.orc.storage.ql.exec.vector.TimestampColumnVector; -import org.apache.spark.sql.catalyst.expressions.SpecializedGetters; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; class SparkOrcValueWriters { private SparkOrcValueWriters() { } - static SparkOrcValueWriter booleans() { - return BooleanWriter.INSTANCE; - } - - static SparkOrcValueWriter bytes() { - return ByteWriter.INSTANCE; - } - - static SparkOrcValueWriter shorts() { - return ShortWriter.INSTANCE; - } - - static SparkOrcValueWriter ints() { - return IntWriter.INSTANCE; - } - - static SparkOrcValueWriter longs() { - return LongWriter.INSTANCE; - } - - static SparkOrcValueWriter floats(int id) { - return new FloatWriter(id); - } - - static SparkOrcValueWriter doubles(int id) { - return new DoubleWriter(id); - } - - static SparkOrcValueWriter byteArrays() { - return BytesWriter.INSTANCE; - } - - static SparkOrcValueWriter strings() { + static OrcValueWriter strings() { return StringWriter.INSTANCE; } - static SparkOrcValueWriter timestampTz() { + static OrcValueWriter timestampTz() { return TimestampTzWriter.INSTANCE; } - static SparkOrcValueWriter decimal(int precision, int scale) { + static OrcValueWriter decimal(int precision, int scale) { if (precision <= 18) { - return new Decimal18Writer(precision, scale); + return new Decimal18Writer(scale); } else { - return new Decimal38Writer(precision, scale); + return new Decimal38Writer(); } } - static SparkOrcValueWriter list(SparkOrcValueWriter element) { - return new ListWriter(element); + static OrcValueWriter list(OrcValueWriter element, List orcType) { + return new ListWriter<>(element, orcType); } - static SparkOrcValueWriter map(SparkOrcValueWriter keyWriter, SparkOrcValueWriter valueWriter) { - return new MapWriter(keyWriter, valueWriter); - } - - private static class BooleanWriter implements SparkOrcValueWriter { - private 
static final BooleanWriter INSTANCE = new BooleanWriter(); - - @Override - public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) { - ((LongColumnVector) output).vector[rowId] = data.getBoolean(column) ? 1 : 0; - } - } - - private static class ByteWriter implements SparkOrcValueWriter { - private static final ByteWriter INSTANCE = new ByteWriter(); - - @Override - public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) { - ((LongColumnVector) output).vector[rowId] = data.getByte(column); - } - } - - private static class ShortWriter implements SparkOrcValueWriter { - private static final ShortWriter INSTANCE = new ShortWriter(); - - @Override - public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) { - ((LongColumnVector) output).vector[rowId] = data.getShort(column); - } - } - - private static class IntWriter implements SparkOrcValueWriter { - private static final IntWriter INSTANCE = new IntWriter(); - - @Override - public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) { - ((LongColumnVector) output).vector[rowId] = data.getInt(column); - } - } - - private static class LongWriter implements SparkOrcValueWriter { - private static final LongWriter INSTANCE = new LongWriter(); - - @Override - public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) { - ((LongColumnVector) output).vector[rowId] = data.getLong(column); - } + static OrcValueWriter map(OrcValueWriter keyWriter, OrcValueWriter valueWriter, + List orcTypes) { + return new MapWriter<>(keyWriter, valueWriter, orcTypes); } - private static class FloatWriter implements SparkOrcValueWriter { - private final FloatFieldMetrics.Builder floatFieldMetricsBuilder; - - private FloatWriter(int id) { - this.floatFieldMetricsBuilder = new FloatFieldMetrics.Builder(id); - } - - @Override - public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) { - float floatValue = data.getFloat(column); - ((DoubleColumnVector) output).vector[rowId] = floatValue; - floatFieldMetricsBuilder.addValue(floatValue); - } - - @Override - public Stream> metrics() { - return Stream.of(floatFieldMetricsBuilder.build()); - } - } - - private static class DoubleWriter implements SparkOrcValueWriter { - private final DoubleFieldMetrics.Builder doubleFieldMetricsBuilder; - - private DoubleWriter(int id) { - this.doubleFieldMetricsBuilder = new DoubleFieldMetrics.Builder(id); - } - - @Override - public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) { - double doubleValue = data.getDouble(column); - ((DoubleColumnVector) output).vector[rowId] = doubleValue; - doubleFieldMetricsBuilder.addValue(doubleValue); - } - - @Override - public Stream> metrics() { - return Stream.of(doubleFieldMetricsBuilder.build()); - } - } - - private static class BytesWriter implements SparkOrcValueWriter { - private static final BytesWriter INSTANCE = new BytesWriter(); - - @Override - public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) { - // getBinary always makes a copy, so we don't need to worry about it - // being changed behind our back. 
- byte[] value = data.getBinary(column); - ((BytesColumnVector) output).setRef(rowId, value, 0, value.length); - } - } - - private static class StringWriter implements SparkOrcValueWriter { + private static class StringWriter implements OrcValueWriter { private static final StringWriter INSTANCE = new StringWriter(); @Override - public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) { - byte[] value = data.getUTF8String(column).getBytes(); + public void nonNullWrite(int rowId, UTF8String data, ColumnVector output) { + byte[] value = data.getBytes(); ((BytesColumnVector) output).setRef(rowId, value, 0, value.length); } + } - private static class TimestampTzWriter implements SparkOrcValueWriter { + private static class TimestampTzWriter implements OrcValueWriter { private static final TimestampTzWriter INSTANCE = new TimestampTzWriter(); @Override - public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) { + public void nonNullWrite(int rowId, Long micros, ColumnVector output) { TimestampColumnVector cv = (TimestampColumnVector) output; - long micros = data.getLong(column); // it could be negative. cv.time[rowId] = Math.floorDiv(micros, 1_000); // millis cv.nanos[rowId] = (int) Math.floorMod(micros, 1_000_000) * 1_000; // nanos } + } - private static class Decimal18Writer implements SparkOrcValueWriter { - private final int precision; + private static class Decimal18Writer implements OrcValueWriter { private final int scale; - Decimal18Writer(int precision, int scale) { - this.precision = precision; + Decimal18Writer(int scale) { this.scale = scale; } @Override - public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) { + public void nonNullWrite(int rowId, Decimal decimal, ColumnVector output) { ((DecimalColumnVector) output).vector[rowId].setFromLongAndScale( - data.getDecimal(column, precision, scale).toUnscaledLong(), scale); + decimal.toUnscaledLong(), scale); } - } - private static class Decimal38Writer implements SparkOrcValueWriter { - private final int precision; - private final int scale; + } - Decimal38Writer(int precision, int scale) { - this.precision = precision; - this.scale = scale; - } + private static class Decimal38Writer implements OrcValueWriter { @Override - public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) { + public void nonNullWrite(int rowId, Decimal decimal, ColumnVector output) { ((DecimalColumnVector) output).vector[rowId].set( - HiveDecimal.create(data.getDecimal(column, precision, scale) - .toJavaBigDecimal())); + HiveDecimal.create(decimal.toJavaBigDecimal())); } + } - private static class ListWriter implements SparkOrcValueWriter { - private final SparkOrcValueWriter writer; + private static class ListWriter implements OrcValueWriter { + private final OrcValueWriter writer; + private final SparkOrcWriter.FieldGetter fieldGetter; - ListWriter(SparkOrcValueWriter writer) { + @SuppressWarnings("unchecked") + ListWriter(OrcValueWriter writer, List orcTypes) { + if (orcTypes.size() != 1) { + throw new IllegalArgumentException("Expected one (and same) ORC type for list elements, got: " + orcTypes); + } this.writer = writer; + this.fieldGetter = (SparkOrcWriter.FieldGetter) SparkOrcWriter.createFieldGetter(orcTypes.get(0)); } @Override - public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) { - ArrayData value = data.getArray(column); + public void nonNullWrite(int 
rowId, ArrayData value, ColumnVector output) { ListColumnVector cv = (ListColumnVector) output; // record the length and start of the list elements cv.lengths[rowId] = value.numElements(); @@ -267,7 +137,7 @@ public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnV growColumnVector(cv.child, cv.childCount); // Add each element for (int e = 0; e < cv.lengths[rowId]; ++e) { - writer.write((int) (e + cv.offsets[rowId]), e, value, cv.child); + writer.write((int) (e + cv.offsets[rowId]), fieldGetter.getFieldOrNull(value, e), cv.child); } } @@ -275,20 +145,28 @@ public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnV public Stream> metrics() { return writer.metrics(); } + } - private static class MapWriter implements SparkOrcValueWriter { - private final SparkOrcValueWriter keyWriter; - private final SparkOrcValueWriter valueWriter; + private static class MapWriter implements OrcValueWriter { + private final OrcValueWriter keyWriter; + private final OrcValueWriter valueWriter; + private final SparkOrcWriter.FieldGetter keyFieldGetter; + private final SparkOrcWriter.FieldGetter valueFieldGetter; - MapWriter(SparkOrcValueWriter keyWriter, SparkOrcValueWriter valueWriter) { + @SuppressWarnings("unchecked") + MapWriter(OrcValueWriter keyWriter, OrcValueWriter valueWriter, List orcTypes) { + if (orcTypes.size() != 2) { + throw new IllegalArgumentException("Expected two ORC type descriptions for a map, got: " + orcTypes); + } this.keyWriter = keyWriter; this.valueWriter = valueWriter; + this.keyFieldGetter = (SparkOrcWriter.FieldGetter) SparkOrcWriter.createFieldGetter(orcTypes.get(0)); + this.valueFieldGetter = (SparkOrcWriter.FieldGetter) SparkOrcWriter.createFieldGetter(orcTypes.get(1)); } @Override - public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) { - MapData map = data.getMap(column); + public void nonNullWrite(int rowId, MapData map, ColumnVector output) { ArrayData key = map.keyArray(); ArrayData value = map.valueArray(); MapColumnVector cv = (MapColumnVector) output; @@ -302,8 +180,8 @@ public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnV // Add each element for (int e = 0; e < cv.lengths[rowId]; ++e) { int pos = (int) (e + cv.offsets[rowId]); - keyWriter.write(pos, e, key, cv.keys); - valueWriter.write(pos, e, value, cv.values); + keyWriter.write(pos, keyFieldGetter.getFieldOrNull(key, e), cv.keys); + valueWriter.write(pos, valueFieldGetter.getFieldOrNull(value, e), cv.values); } } @@ -311,6 +189,7 @@ public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnV public Stream> metrics() { return Stream.concat(keyWriter.metrics(), valueWriter.metrics()); } + } private static void growColumnVector(ColumnVector cv, int requestedSize) { diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java index 2c1edea1ffef..34292f23b135 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java +++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcWriter.java @@ -19,19 +19,22 @@ package org.apache.iceberg.spark.data; +import java.io.Serializable; import java.util.List; import java.util.stream.Stream; +import javax.annotation.Nullable; import org.apache.iceberg.FieldMetrics; import org.apache.iceberg.Schema; +import org.apache.iceberg.data.orc.GenericOrcWriters; import org.apache.iceberg.orc.ORCSchemaUtil; import 
org.apache.iceberg.orc.OrcRowWriter; import org.apache.iceberg.orc.OrcSchemaWithTypeVisitor; +import org.apache.iceberg.orc.OrcValueWriter; import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.types.Type; import org.apache.iceberg.types.Types; import org.apache.orc.TypeDescription; -import org.apache.orc.storage.ql.exec.vector.ColumnVector; -import org.apache.orc.storage.ql.exec.vector.StructColumnVector; import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.SpecializedGetters; @@ -42,26 +45,19 @@ */ public class SparkOrcWriter implements OrcRowWriter { - private final SparkOrcValueWriter writer; + private final InternalRowWriter writer; public SparkOrcWriter(Schema iSchema, TypeDescription orcSchema) { Preconditions.checkArgument(orcSchema.getCategory() == TypeDescription.Category.STRUCT, "Top level must be a struct " + orcSchema); - writer = OrcSchemaWithTypeVisitor.visit(iSchema, orcSchema, new WriteBuilder()); + writer = (InternalRowWriter) OrcSchemaWithTypeVisitor.visit(iSchema, orcSchema, new WriteBuilder()); } @Override public void write(InternalRow value, VectorizedRowBatch output) { - Preconditions.checkArgument(writer instanceof StructWriter, "writer must be StructWriter"); - - int row = output.size; - output.size += 1; - List writers = ((StructWriter) writer).writers(); - for (int c = 0; c < writers.size(); c++) { - SparkOrcValueWriter child = writers.get(c); - child.write(row, c, value, output.cols[c]); - } + Preconditions.checkArgument(value != null, "value must not be null"); + writer.writeRow(value, output); } @Override @@ -69,48 +65,48 @@ public Stream> metrics() { return writer.metrics(); } - private static class WriteBuilder extends OrcSchemaWithTypeVisitor { + private static class WriteBuilder extends OrcSchemaWithTypeVisitor> { private WriteBuilder() { } @Override - public SparkOrcValueWriter record(Types.StructType iStruct, TypeDescription record, - List names, List fields) { - return new StructWriter(fields); + public OrcValueWriter record(Types.StructType iStruct, TypeDescription record, + List names, List> fields) { + return new InternalRowWriter(fields, record.getChildren()); } @Override - public SparkOrcValueWriter list(Types.ListType iList, TypeDescription array, - SparkOrcValueWriter element) { - return SparkOrcValueWriters.list(element); + public OrcValueWriter list(Types.ListType iList, TypeDescription array, + OrcValueWriter element) { + return SparkOrcValueWriters.list(element, array.getChildren()); } @Override - public SparkOrcValueWriter map(Types.MapType iMap, TypeDescription map, - SparkOrcValueWriter key, SparkOrcValueWriter value) { - return SparkOrcValueWriters.map(key, value); + public OrcValueWriter map(Types.MapType iMap, TypeDescription map, + OrcValueWriter key, OrcValueWriter value) { + return SparkOrcValueWriters.map(key, value, map.getChildren()); } @Override - public SparkOrcValueWriter primitive(Type.PrimitiveType iPrimitive, TypeDescription primitive) { + public OrcValueWriter primitive(Type.PrimitiveType iPrimitive, TypeDescription primitive) { switch (primitive.getCategory()) { case BOOLEAN: - return SparkOrcValueWriters.booleans(); + return GenericOrcWriters.booleans(); case BYTE: - return SparkOrcValueWriters.bytes(); + return GenericOrcWriters.bytes(); case SHORT: - return SparkOrcValueWriters.shorts(); + return 
GenericOrcWriters.shorts(); case DATE: case INT: - return SparkOrcValueWriters.ints(); + return GenericOrcWriters.ints(); case LONG: - return SparkOrcValueWriters.longs(); + return GenericOrcWriters.longs(); case FLOAT: - return SparkOrcValueWriters.floats(ORCSchemaUtil.fieldId(primitive)); + return GenericOrcWriters.floats(ORCSchemaUtil.fieldId(primitive)); case DOUBLE: - return SparkOrcValueWriters.doubles(ORCSchemaUtil.fieldId(primitive)); + return GenericOrcWriters.doubles(ORCSchemaUtil.fieldId(primitive)); case BINARY: - return SparkOrcValueWriters.byteArrays(); + return GenericOrcWriters.byteArrays(); case STRING: case CHAR: case VARCHAR: @@ -126,30 +122,96 @@ public SparkOrcValueWriter primitive(Type.PrimitiveType iPrimitive, TypeDescript } } - private static class StructWriter implements SparkOrcValueWriter { - private final List writers; + private static class InternalRowWriter extends GenericOrcWriters.StructWriter { + private final List> fieldGetters; - StructWriter(List writers) { - this.writers = writers; - } + InternalRowWriter(List> writers, List orcTypes) { + super(writers); + this.fieldGetters = Lists.newArrayListWithExpectedSize(orcTypes.size()); - List writers() { - return writers; + for (int i = 0; i < orcTypes.size(); i++) { + fieldGetters.add(createFieldGetter(orcTypes.get(i))); + } } @Override - public void nonNullWrite(int rowId, int column, SpecializedGetters data, ColumnVector output) { - InternalRow value = data.getStruct(column, writers.size()); - StructColumnVector cv = (StructColumnVector) output; - for (int c = 0; c < writers.size(); ++c) { - writers.get(c).write(rowId, c, value, cv.fields[c]); - } + protected Object get(InternalRow struct, int index) { + return fieldGetters.get(index).getFieldOrNull(struct, index); } + } - @Override - public Stream> metrics() { - return writers.stream().flatMap(SparkOrcValueWriter::metrics); + static FieldGetter createFieldGetter(TypeDescription fieldType) { + final FieldGetter fieldGetter; + switch (fieldType.getCategory()) { + case BOOLEAN: + fieldGetter = SpecializedGetters::getBoolean; + break; + case BYTE: + fieldGetter = SpecializedGetters::getByte; + break; + case SHORT: + fieldGetter = SpecializedGetters::getShort; + break; + case DATE: + case INT: + fieldGetter = SpecializedGetters::getInt; + break; + case LONG: + case TIMESTAMP: + case TIMESTAMP_INSTANT: + fieldGetter = SpecializedGetters::getLong; + break; + case FLOAT: + fieldGetter = SpecializedGetters::getFloat; + break; + case DOUBLE: + fieldGetter = SpecializedGetters::getDouble; + break; + case BINARY: + fieldGetter = SpecializedGetters::getBinary; + // getBinary always makes a copy, so we don't need to worry about it + // being changed behind our back. 
+ break; + case DECIMAL: + fieldGetter = (row, ordinal) -> + row.getDecimal(ordinal, fieldType.getPrecision(), fieldType.getScale()); + break; + case STRING: + case CHAR: + case VARCHAR: + fieldGetter = SpecializedGetters::getUTF8String; + break; + case STRUCT: + fieldGetter = (row, ordinal) -> row.getStruct(ordinal, fieldType.getChildren().size()); + break; + case LIST: + fieldGetter = SpecializedGetters::getArray; + break; + case MAP: + fieldGetter = SpecializedGetters::getMap; + break; + default: + throw new IllegalArgumentException("Encountered an unsupported ORC type during a write from Spark."); } + return (row, ordinal) -> { + if (row.isNullAt(ordinal)) { + return null; + } + return fieldGetter.getFieldOrNull(row, ordinal); + }; + } + + interface FieldGetter extends Serializable { + + /** + * Returns a value from a complex Spark data holder such ArrayData, InternalRow, etc... + * Calls the appropriate getter for the expected data type. + * @param row Spark's data representation + * @param ordinal index in the data structure (e.g. column index for InterRow, list index in ArrayData, etc..) + * @return field value at ordinal + */ + @Nullable + T getFieldOrNull(SpecializedGetters row, int ordinal); } }
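Taken together, the Spark changes mirror the generic and Flink ones: the standalone `SparkOrcValueWriter` interface is gone, primitives reuse `GenericOrcWriters`, and column extraction from Spark's `SpecializedGetters` moves into per-column `FieldGetter`s derived from the ORC `TypeDescription`, with the null check folded into the getter that `createFieldGetter` returns. A hedged end-to-end sketch of the new write path (schema, field ids, and values are illustrative, not from the patch):

```java
import org.apache.iceberg.Schema;
import org.apache.iceberg.orc.ORCSchemaUtil;
import org.apache.iceberg.spark.data.SparkOrcWriter;
import org.apache.iceberg.types.Types;
import org.apache.orc.TypeDescription;
import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.unsafe.types.UTF8String;

public class SparkOrcWriterSketch {
  public static void main(String[] args) {
    // Hypothetical two-column Iceberg schema.
    Schema schema = new Schema(
        Types.NestedField.required(1, "id", Types.LongType.get()),
        Types.NestedField.optional(2, "data", Types.StringType.get()));
    TypeDescription orcSchema = ORCSchemaUtil.convert(schema);

    SparkOrcWriter writer = new SparkOrcWriter(schema, orcSchema);
    VectorizedRowBatch batch = orcSchema.createRowBatch();

    // write(...) now just delegates to InternalRowWriter.writeRow(...); the per-column
    // FieldGetters created from the ORC child types call isNullAt(ordinal) before
    // handing the value to the corresponding child OrcValueWriter.
    InternalRow row = new GenericInternalRow(new Object[] {1L, UTF8String.fromString("a")});
    writer.write(row, batch);
  }
}
```

One design consequence worth noting: because `FieldGetter` is `Serializable` and keyed off the ORC type rather than `getJavaClass()`, the same getters serve top-level rows, struct fields, list elements, and map keys/values, which is what lets `SparkOrcValueWriters.ListWriter` and `MapWriter` drop the old `(rowId, column, SpecializedGetters)` calling convention.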