Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,16 @@
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;
import org.apache.orc.TypeDescription;
import org.apache.orc.storage.ql.exec.vector.ColumnVector;
import org.apache.orc.storage.ql.exec.vector.StructColumnVector;
import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch;

public class GenericOrcWriter implements OrcRowWriter<Record> {
private final OrcValueWriter writer;
private final RecordWriter writer;

private GenericOrcWriter(Schema expectedSchema, TypeDescription orcSchema) {
Preconditions.checkArgument(orcSchema.getCategory() == TypeDescription.Category.STRUCT,
"Top level must be a struct " + orcSchema);

writer = OrcSchemaWithTypeVisitor.visit(expectedSchema, orcSchema, new WriteBuilder());
writer = (RecordWriter) OrcSchemaWithTypeVisitor.visit(expectedSchema, orcSchema, new WriteBuilder());
}

public static OrcRowWriter<Record> buildWriter(Schema expectedSchema, TypeDescription fileSchema) {
Expand Down Expand Up @@ -115,53 +113,25 @@ public OrcValueWriter<?> primitive(Type.PrimitiveType iPrimitive, TypeDescriptio
}

@Override
@SuppressWarnings("unchecked")
public void write(Record value, VectorizedRowBatch output) {
Preconditions.checkArgument(writer instanceof RecordWriter, "writer must be a RecordWriter.");

int row = output.size;
output.size += 1;
List<OrcValueWriter<?>> writers = ((RecordWriter) writer).writers();
for (int c = 0; c < writers.size(); ++c) {
OrcValueWriter child = writers.get(c);
child.write(row, value.get(c, child.getJavaClass()), output.cols[c]);
}
Preconditions.checkArgument(value != null, "value must not be null");
writer.writeRow(value, output);
}

@Override
public Stream<FieldMetrics<?>> metrics() {
return writer.metrics();
}

private static class RecordWriter implements OrcValueWriter<Record> {
private final List<OrcValueWriter<?>> writers;
private static class RecordWriter extends GenericOrcWriters.StructWriter<Record> {

RecordWriter(List<OrcValueWriter<?>> writers) {
this.writers = writers;
}

List<OrcValueWriter<?>> writers() {
return writers;
}

@Override
public Class<Record> getJavaClass() {
return Record.class;
}

@Override
@SuppressWarnings("unchecked")
public void nonNullWrite(int rowId, Record data, ColumnVector output) {
StructColumnVector cv = (StructColumnVector) output;
for (int c = 0; c < writers.size(); ++c) {
OrcValueWriter child = writers.get(c);
child.write(rowId, data.get(c, child.getJavaClass()), cv.fields[c]);
}
super(writers);
}

@Override
public Stream<FieldMetrics<?>> metrics() {
return writers.stream().flatMap(OrcValueWriter::metrics);
protected Object get(Record struct, int index) {
return struct.get(index);
}
}
}
138 changes: 43 additions & 95 deletions data/src/main/java/org/apache/iceberg/data/orc/GenericOrcWriters.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.function.Function;
import java.util.stream.Stream;
import org.apache.iceberg.DoubleFieldMetrics;
import org.apache.iceberg.FieldMetrics;
Expand All @@ -48,7 +49,9 @@
import org.apache.orc.storage.ql.exec.vector.ListColumnVector;
import org.apache.orc.storage.ql.exec.vector.LongColumnVector;
import org.apache.orc.storage.ql.exec.vector.MapColumnVector;
import org.apache.orc.storage.ql.exec.vector.StructColumnVector;
import org.apache.orc.storage.ql.exec.vector.TimestampColumnVector;
import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch;

public class GenericOrcWriters {
private static final OffsetDateTime EPOCH = Instant.ofEpochSecond(0).atOffset(ZoneOffset.UTC);
Expand Down Expand Up @@ -138,11 +141,6 @@ public static <K, V> OrcValueWriter<Map<K, V>> map(OrcValueWriter<K> key, OrcVal
private static class BooleanWriter implements OrcValueWriter<Boolean> {
private static final OrcValueWriter<Boolean> INSTANCE = new BooleanWriter();

@Override
public Class<Boolean> getJavaClass() {
return Boolean.class;
}

@Override
public void nonNullWrite(int rowId, Boolean data, ColumnVector output) {
((LongColumnVector) output).vector[rowId] = data ? 1 : 0;
Expand All @@ -152,11 +150,6 @@ public void nonNullWrite(int rowId, Boolean data, ColumnVector output) {
private static class ByteWriter implements OrcValueWriter<Byte> {
private static final OrcValueWriter<Byte> INSTANCE = new ByteWriter();

@Override
public Class<Byte> getJavaClass() {
return Byte.class;
}

@Override
public void nonNullWrite(int rowId, Byte data, ColumnVector output) {
((LongColumnVector) output).vector[rowId] = data;
Expand All @@ -166,11 +159,6 @@ public void nonNullWrite(int rowId, Byte data, ColumnVector output) {
private static class ShortWriter implements OrcValueWriter<Short> {
private static final OrcValueWriter<Short> INSTANCE = new ShortWriter();

@Override
public Class<Short> getJavaClass() {
return Short.class;
}

@Override
public void nonNullWrite(int rowId, Short data, ColumnVector output) {
((LongColumnVector) output).vector[rowId] = data;
Expand All @@ -180,11 +168,6 @@ public void nonNullWrite(int rowId, Short data, ColumnVector output) {
private static class IntWriter implements OrcValueWriter<Integer> {
private static final OrcValueWriter<Integer> INSTANCE = new IntWriter();

@Override
public Class<Integer> getJavaClass() {
return Integer.class;
}

@Override
public void nonNullWrite(int rowId, Integer data, ColumnVector output) {
((LongColumnVector) output).vector[rowId] = data;
Expand All @@ -194,11 +177,6 @@ public void nonNullWrite(int rowId, Integer data, ColumnVector output) {
private static class TimeWriter implements OrcValueWriter<LocalTime> {
private static final OrcValueWriter<LocalTime> INSTANCE = new TimeWriter();

@Override
public Class<LocalTime> getJavaClass() {
return LocalTime.class;
}

@Override
public void nonNullWrite(int rowId, LocalTime data, ColumnVector output) {
((LongColumnVector) output).vector[rowId] = data.toNanoOfDay() / 1_000;
Expand All @@ -208,11 +186,6 @@ public void nonNullWrite(int rowId, LocalTime data, ColumnVector output) {
private static class LongWriter implements OrcValueWriter<Long> {
private static final OrcValueWriter<Long> INSTANCE = new LongWriter();

@Override
public Class<Long> getJavaClass() {
return Long.class;
}

@Override
public void nonNullWrite(int rowId, Long data, ColumnVector output) {
((LongColumnVector) output).vector[rowId] = data;
Expand All @@ -227,11 +200,6 @@ private FloatWriter(int id) {
this.floatFieldMetricsBuilder = new FloatFieldMetrics.Builder(id);
}

@Override
public Class<Float> getJavaClass() {
return Float.class;
}

@Override
public void nonNullWrite(int rowId, Float data, ColumnVector output) {
((DoubleColumnVector) output).vector[rowId] = data;
Expand Down Expand Up @@ -261,11 +229,6 @@ private DoubleWriter(Integer id) {
this.doubleFieldMetricsBuilder = new DoubleFieldMetrics.Builder(id);
}

@Override
public Class<Double> getJavaClass() {
return Double.class;
}

@Override
public void nonNullWrite(int rowId, Double data, ColumnVector output) {
((DoubleColumnVector) output).vector[rowId] = data;
Expand All @@ -290,11 +253,6 @@ public Stream<FieldMetrics<?>> metrics() {
private static class StringWriter implements OrcValueWriter<String> {
private static final OrcValueWriter<String> INSTANCE = new StringWriter();

@Override
public Class<String> getJavaClass() {
return String.class;
}

@Override
public void nonNullWrite(int rowId, String data, ColumnVector output) {
byte[] value = data.getBytes(StandardCharsets.UTF_8);
Expand All @@ -305,11 +263,6 @@ public void nonNullWrite(int rowId, String data, ColumnVector output) {
private static class ByteBufferWriter implements OrcValueWriter<ByteBuffer> {
private static final OrcValueWriter<ByteBuffer> INSTANCE = new ByteBufferWriter();

@Override
public Class<ByteBuffer> getJavaClass() {
return ByteBuffer.class;
}

@Override
public void nonNullWrite(int rowId, ByteBuffer data, ColumnVector output) {
if (data.hasArray()) {
Expand All @@ -325,11 +278,6 @@ public void nonNullWrite(int rowId, ByteBuffer data, ColumnVector output) {
private static class UUIDWriter implements OrcValueWriter<UUID> {
private static final OrcValueWriter<UUID> INSTANCE = new UUIDWriter();

@Override
public Class<UUID> getJavaClass() {
return UUID.class;
}

@Override
@SuppressWarnings("ByteBufferBackingArray")
public void nonNullWrite(int rowId, UUID data, ColumnVector output) {
Expand All @@ -343,11 +291,6 @@ public void nonNullWrite(int rowId, UUID data, ColumnVector output) {
private static class ByteArrayWriter implements OrcValueWriter<byte[]> {
private static final OrcValueWriter<byte[]> INSTANCE = new ByteArrayWriter();

@Override
public Class<byte[]> getJavaClass() {
return byte[].class;
}

@Override
public void nonNullWrite(int rowId, byte[] data, ColumnVector output) {
((BytesColumnVector) output).setRef(rowId, data, 0, data.length);
Expand All @@ -357,11 +300,6 @@ public void nonNullWrite(int rowId, byte[] data, ColumnVector output) {
private static class DateWriter implements OrcValueWriter<LocalDate> {
private static final OrcValueWriter<LocalDate> INSTANCE = new DateWriter();

@Override
public Class<LocalDate> getJavaClass() {
return LocalDate.class;
}

@Override
public void nonNullWrite(int rowId, LocalDate data, ColumnVector output) {
((LongColumnVector) output).vector[rowId] = ChronoUnit.DAYS.between(EPOCH_DAY, data);
Expand All @@ -371,11 +309,6 @@ public void nonNullWrite(int rowId, LocalDate data, ColumnVector output) {
private static class TimestampTzWriter implements OrcValueWriter<OffsetDateTime> {
private static final OrcValueWriter<OffsetDateTime> INSTANCE = new TimestampTzWriter();

@Override
public Class<OffsetDateTime> getJavaClass() {
return OffsetDateTime.class;
}

@Override
public void nonNullWrite(int rowId, OffsetDateTime data, ColumnVector output) {
TimestampColumnVector cv = (TimestampColumnVector) output;
Expand All @@ -389,11 +322,6 @@ public void nonNullWrite(int rowId, OffsetDateTime data, ColumnVector output) {
private static class TimestampWriter implements OrcValueWriter<LocalDateTime> {
private static final OrcValueWriter<LocalDateTime> INSTANCE = new TimestampWriter();

@Override
public Class<LocalDateTime> getJavaClass() {
return LocalDateTime.class;
}

@Override
public void nonNullWrite(int rowId, LocalDateTime data, ColumnVector output) {
TimestampColumnVector cv = (TimestampColumnVector) output;
Expand All @@ -412,11 +340,6 @@ private static class Decimal18Writer implements OrcValueWriter<BigDecimal> {
this.scale = scale;
}

@Override
public Class<BigDecimal> getJavaClass() {
return BigDecimal.class;
}

@Override
public void nonNullWrite(int rowId, BigDecimal data, ColumnVector output) {
Preconditions.checkArgument(data.scale() == scale,
Expand All @@ -438,11 +361,6 @@ private static class Decimal38Writer implements OrcValueWriter<BigDecimal> {
this.scale = scale;
}

@Override
public Class<BigDecimal> getJavaClass() {
return BigDecimal.class;
}

@Override
public void nonNullWrite(int rowId, BigDecimal data, ColumnVector output) {
Preconditions.checkArgument(data.scale() == scale,
Expand All @@ -461,11 +379,6 @@ private static class ListWriter<T> implements OrcValueWriter<List<T>> {
this.element = element;
}

@Override
public Class<?> getJavaClass() {
return List.class;
}

@Override
public void nonNullWrite(int rowId, List<T> value, ColumnVector output) {
ListColumnVector cv = (ListColumnVector) output;
Expand Down Expand Up @@ -496,11 +409,6 @@ private static class MapWriter<K, V> implements OrcValueWriter<Map<K, V>> {
this.valueWriter = valueWriter;
}

@Override
public Class<?> getJavaClass() {
return Map.class;
}

@Override
public void nonNullWrite(int rowId, Map<K, V> map, ColumnVector output) {
List<K> keys = Lists.newArrayListWithExpectedSize(map.size());
Expand Down Expand Up @@ -531,6 +439,46 @@ public Stream<FieldMetrics<?>> metrics() {
}
}

/**
 * Base writer for struct-typed values: delegates each field of the struct to
 * a per-field {@link OrcValueWriter}, writing into the matching child
 * {@link ColumnVector}. Subclasses only supply field access via {@link #get}.
 */
public abstract static class StructWriter<S> implements OrcValueWriter<S> {
  private final List<OrcValueWriter<?>> writers;

  protected StructWriter(List<OrcValueWriter<?>> writers) {
    this.writers = writers;
  }

  /** Returns the per-field writers, in field order. */
  public List<OrcValueWriter<?>> writers() {
    return writers;
  }

  @Override
  public Stream<FieldMetrics<?>> metrics() {
    // Aggregate metrics from every field writer into a single stream.
    return writers.stream().flatMap(OrcValueWriter::metrics);
  }

  @Override
  public void nonNullWrite(int rowId, S value, ColumnVector output) {
    // Nested struct: field vectors come from the StructColumnVector's children.
    StructColumnVector vector = (StructColumnVector) output;
    writeFields(rowId, value, idx -> vector.fields[idx]);
  }

  /**
   * Writes {@code value} as the next row of the batch. Special case for the
   * root struct, whose field vectors are the batch's top-level columns.
   */
  public void writeRow(S value, VectorizedRowBatch output) {
    int row = output.size;
    output.size += 1;
    writeFields(row, value, idx -> output.cols[idx]);
  }

  // Shared field loop; the function maps a field index to its ColumnVector.
  @SuppressWarnings({"unchecked", "rawtypes"})
  private void writeFields(int rowId, S value, Function<Integer, ColumnVector> fieldVector) {
    int fieldCount = writers.size();
    for (int idx = 0; idx < fieldCount; idx += 1) {
      OrcValueWriter fieldWriter = writers.get(idx);
      fieldWriter.write(rowId, get(value, idx), fieldVector.apply(idx));
    }
  }

  /** Returns the value of field {@code index} from the given struct. */
  protected abstract Object get(S struct, int index);
}

private static void growColumnVector(ColumnVector cv, int requestedSize) {
if (cv.isNull.length < requestedSize) {
// Use growth factor of 3 to avoid frequent array allocations
Expand Down
Loading