From ee6cde38c3772239a18110ec8ce4b075a2ee49a9 Mon Sep 17 00:00:00 2001 From: Henry Mai Date: Mon, 19 Sep 2022 21:00:03 +0000 Subject: [PATCH] GH-35352: [Java] Fix issues with "semi complex" types. "semi complex" types like TimeStamp*TZ, Duration, and FixedSizeBinary were missing implementations in UnionListWriter, UnionVector, UnionReader and other associated classes. This patch adds these missing methods so that these types can now be written to things like ListVectors, whereas before it would throw an exception because the methods were just not implemented. For example, without this patch, one of the new tests added would fail: ``` TestListVector.testWriterGetTimestampMilliTZField:913 ? IllegalArgument You tried to write a TimeStampMilliTZ type when you are using a ValueWriter of type UnionListWriter. ``` There are also fixes for get and set methods for holders for the respective *Vectors classes for these types: - The get methods did not set fields like TimeStampMilliTZHolder.timezone, DurationHolder.unit, FixedSizeBinaryHolder.byteWidth. - The set methods did not all validate that those fields matched what the vector's ArrowType was set to. For example TimeStampMilliTZHolder.timezone should match ArrowType.Timestamp.timezone on the vector and should throw if it doesn't. Otherwise users would never get a signal that there is anything wrong with their code writing these holders with mismatching values. This patch additionally marks some of the existing interfaces for writing these "semi complex" types as deprecated, because they do not actually provide enough context to properly write these parameterized ArrowTypes. Instead, users should use the write methods that take *Holders that do provide enough context for the writers to properly write these types. Also note that the equivalent write methods for Decimal have already also been marked as deprecated, for the same reasoning. --- .../AbstractPromotableFieldWriter.java | 83 ++++++- .../codegen/templates/ComplexWriters.java | 21 +- .../main/codegen/templates/StructWriters.java | 19 ++ .../codegen/templates/UnionListWriter.java | 128 ++++------ .../main/codegen/templates/UnionReader.java | 9 +- .../main/codegen/templates/UnionVector.java | 64 +++-- .../main/codegen/templates/UnionWriter.java | 87 +++++-- .../apache/arrow/vector/DurationVector.java | 8 + .../arrow/vector/FixedSizeBinaryVector.java | 10 +- .../arrow/vector/TimeStampMicroTZVector.java | 8 + .../arrow/vector/TimeStampMilliTZVector.java | 8 + .../arrow/vector/TimeStampNanoTZVector.java | 8 + .../arrow/vector/TimeStampSecTZVector.java | 8 + .../vector/complex/impl/PromotableWriter.java | 12 +- .../vector/TestFixedSizeBinaryVector.java | 16 +- .../apache/arrow/vector/TestListVector.java | 144 +++++++++++ .../apache/arrow/vector/TestUnionVector.java | 23 ++ .../complex/impl/TestPromotableWriter.java | 234 +++++++++++++++++- .../complex/writer/TestComplexWriter.java | 168 ++++++++++++- 19 files changed, 909 insertions(+), 149 deletions(-) diff --git a/java/vector/src/main/codegen/templates/AbstractPromotableFieldWriter.java b/java/vector/src/main/codegen/templates/AbstractPromotableFieldWriter.java index 264e85021858..2f963a9df0d0 100644 --- a/java/vector/src/main/codegen/templates/AbstractPromotableFieldWriter.java +++ b/java/vector/src/main/codegen/templates/AbstractPromotableFieldWriter.java @@ -25,6 +25,10 @@ <#include "/@includes/vv_imports.ftl" /> +<#function is_timestamp_tz type> + <#return type?starts_with("TimeStamp") && type?ends_with("TZ")> + + /* * A FieldWriter which delegates calls to another FieldWriter. The delegate FieldWriter can be promoted to a new type * when necessary. Classes that extend this class are responsible for handling promotion. @@ -105,17 +109,7 @@ public void endEntry() { <#list vv.types as type><#list type.minor as minor><#assign name = minor.class?cap_first /> <#assign fields = minor.fields!type.fields /> - <#if minor.class != "Decimal" && minor.class != "Decimal256"> - @Override - public void write(${name}Holder holder) { - getWriter(MinorType.${name?upper_case}).write(holder); - } - - public void write${minor.class}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, ) { - getWriter(MinorType.${name?upper_case}).write${minor.class}(<#list fields as field>${field.name}<#if field_has_next>, ); - } - - <#elseif minor.class == "Decimal"> + <#if minor.class == "Decimal"> @Override public void write(DecimalHolder holder) { getWriter(MinorType.DECIMAL).write(holder); @@ -156,8 +150,75 @@ public void writeBigEndianBytesToDecimal256(byte[] value, ArrowType arrowType) { public void writeBigEndianBytesToDecimal256(byte[] value) { getWriter(MinorType.DECIMAL256).writeBigEndianBytesToDecimal256(value); } + <#elseif is_timestamp_tz(minor.class)> + @Override + public void write(${name}Holder holder) { + ArrowType.Timestamp arrowTypeWithoutTz = (ArrowType.Timestamp) MinorType.${name?upper_case?remove_ending("TZ")}.getType(); + // Take the holder.timezone similar to how PromotableWriter.java:write(DecimalHolder) takes the scale from the holder. + ArrowType.Timestamp arrowType = new ArrowType.Timestamp(arrowTypeWithoutTz.getUnit(), holder.timezone); + getWriter(MinorType.${name?upper_case}, arrowType).write(holder); + } + /** + * @deprecated + * The holder version should be used instead otherwise the timezone will default to UTC. + * @see #write(${name}Holder) + */ + @Deprecated + @Override + public void write${minor.class}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, ) { + ArrowType.Timestamp arrowTypeWithoutTz = (ArrowType.Timestamp) MinorType.${name?upper_case?remove_ending("TZ")}.getType(); + // Assumes UTC if no timezone is provided + ArrowType.Timestamp arrowType = new ArrowType.Timestamp(arrowTypeWithoutTz.getUnit(), "UTC"); + getWriter(MinorType.${name?upper_case}, arrowType).write${minor.class}(<#list fields as field>${field.name}<#if field_has_next>, ); + } + <#elseif minor.class == "Duration"> + @Override + public void write(${name}Holder holder) { + ArrowType.Duration arrowType = new ArrowType.Duration(holder.unit); + getWriter(MinorType.${name?upper_case}, arrowType).write(holder); + } + /** + * @deprecated + * If you experience errors with using this version of the method, switch to the holder version. + * The errors occur when using an untyped or unioned PromotableWriter, because this version of the + * method does not have enough information to infer the ArrowType. + * @see #write(${name}Holder) + */ + @Deprecated + @Override + public void write${minor.class}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, ) { + getWriter(MinorType.${name?upper_case}).write${minor.class}(<#list fields as field>${field.name}<#if field_has_next>, ); + } + <#elseif minor.class == "FixedSizeBinary"> + @Override + public void write(${name}Holder holder) { + ArrowType.FixedSizeBinary arrowType = new ArrowType.FixedSizeBinary(holder.byteWidth); + getWriter(MinorType.${name?upper_case}, arrowType).write(holder); + } + + /** + * @deprecated + * If you experience errors with using this version of the method, switch to the holder version. + * The errors occur when using an untyped or unioned PromotableWriter, because this version of the + * method does not have enough information to infer the ArrowType. + * @see #write(${name}Holder) + */ + @Deprecated + @Override + public void write${minor.class}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, ) { + getWriter(MinorType.${name?upper_case}).write${minor.class}(<#list fields as field>${field.name}<#if field_has_next>, ); + } + <#else> + @Override + public void write(${name}Holder holder) { + getWriter(MinorType.${name?upper_case}).write(holder); + } + + public void write${minor.class}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, ) { + getWriter(MinorType.${name?upper_case}).write${minor.class}(<#list fields as field>${field.name}<#if field_has_next>, ); + } diff --git a/java/vector/src/main/codegen/templates/ComplexWriters.java b/java/vector/src/main/codegen/templates/ComplexWriters.java index 0381e5559e45..0b1e321afb70 100644 --- a/java/vector/src/main/codegen/templates/ComplexWriters.java +++ b/java/vector/src/main/codegen/templates/ComplexWriters.java @@ -32,6 +32,10 @@ <#include "/@includes/vv_imports.ftl" /> +<#function is_timestamp_tz type> + <#return type?starts_with("TimeStamp") && type?ends_with("TZ")> + + /* * This class is generated using FreeMarker on the ${.template_name} template. */ @@ -191,7 +195,15 @@ public void writeNull() { public interface ${eName}Writer extends BaseWriter { public void write(${minor.class}Holder h); - <#if minor.class?starts_with("Decimal")>@Deprecated +<#if minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> + /** + * @deprecated + * The holder version should be used instead because the plain value version does not contain enough information + * to fully specify this field type. + * @see #write(${minor.class}Holder) + */ + @Deprecated + public void write${minor.class}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, ); <#if minor.class?starts_with("Decimal")> @@ -201,6 +213,13 @@ public interface ${eName}Writer extends BaseWriter { public void writeBigEndianBytesTo${minor.class}(byte[] value, ArrowType arrowType); + /** + * @deprecated + * Use either the version that additionally takes in an ArrowType or use the holder version. + * This version does not contain enough information to fully specify this field type. + * @see #writeBigEndianBytesTo${minor.class}(byte[], ArrowType) + * @see #write(${minor.class}Holder) + */ @Deprecated public void writeBigEndianBytesTo${minor.class}(byte[] value); diff --git a/java/vector/src/main/codegen/templates/StructWriters.java b/java/vector/src/main/codegen/templates/StructWriters.java index 69693c63011c..84e5d8113b32 100644 --- a/java/vector/src/main/codegen/templates/StructWriters.java +++ b/java/vector/src/main/codegen/templates/StructWriters.java @@ -38,6 +38,10 @@ import org.apache.arrow.vector.complex.reader.FieldReader; import org.apache.arrow.vector.complex.writer.FieldWriter; +<#function is_timestamp_tz type> + <#return type?starts_with("TimeStamp") && type?ends_with("TZ")> + + /* * This class is generated using FreeMarker and the ${.template_name} template. */ @@ -314,7 +318,22 @@ public void end() { } else { if (writer instanceof PromotableWriter) { // ensure writers are initialized + <#if minor.class?starts_with("Decimal")> ((PromotableWriter)writer).getWriter(MinorType.${upperName}<#if minor.class?starts_with("Decimal")>, new ${minor.arrowType}(precision, scale, ${vectName}Vector.TYPE_WIDTH * 8)); + <#elseif is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> + <#if minor.arrowTypeConstructorParams??> + <#assign constructorParams = minor.arrowTypeConstructorParams /> + <#else> + <#assign constructorParams = [] /> + <#list minor.typeParams?reverse as typeParam> + <#assign constructorParams = constructorParams + [ typeParam.name ] /> + + + ArrowType arrowType = new ${minor.arrowType}(${constructorParams?join(", ")}); + ((PromotableWriter)writer).getWriter(MinorType.${upperName}, arrowType); + <#else> + ((PromotableWriter)writer).getWriter(MinorType.${upperName}); + } } return writer; diff --git a/java/vector/src/main/codegen/templates/UnionListWriter.java b/java/vector/src/main/codegen/templates/UnionListWriter.java index 926276b5eb46..fac75a9ce563 100644 --- a/java/vector/src/main/codegen/templates/UnionListWriter.java +++ b/java/vector/src/main/codegen/templates/UnionListWriter.java @@ -37,6 +37,10 @@ import static org.apache.arrow.memory.util.LargeMemoryUtil.checkedCastToInt; <#include "/@includes/vv_imports.ftl" /> +<#function is_timestamp_tz type> + <#return type?starts_with("TimeStamp") && type?ends_with("TZ")> + + /* * This class is generated using freemarker and the ${.template_name} template. */ @@ -103,55 +107,31 @@ public void setPosition(int index) { super.setPosition(index); } - <#list vv.types as type><#list type.minor as minor><#assign name = minor.class?cap_first /> - <#assign fields = minor.fields!type.fields /> - <#assign uncappedName = name?uncap_first/> - <#if uncappedName == "int" ><#assign uncappedName = "integer" /> - <#if !minor.typeParams?? > - + <#list vv.types as type><#list type.minor as minor> + <#assign lowerName = minor.class?uncap_first /> + <#if lowerName == "int" ><#assign lowerName = "integer" /> + <#assign upperName = minor.class?upper_case /> + <#assign capName = minor.class?cap_first /> + <#assign vectName = capName /> @Override - public ${name}Writer ${uncappedName}() { + public ${minor.class}Writer ${lowerName}() { return this; } + <#if minor.typeParams?? > @Override - public ${name}Writer ${uncappedName}(String name) { - structName = name; - return writer.${uncappedName}(name); + public ${minor.class}Writer ${lowerName}(String name<#list minor.typeParams as typeParam>, ${typeParam.type} ${typeParam.name}) { + return writer.${lowerName}(name<#list minor.typeParams as typeParam>, ${typeParam.name}); } - - - @Override - public DecimalWriter decimal() { - return this; - } - - @Override - public DecimalWriter decimal(String name, int scale, int precision) { - return writer.decimal(name, scale, precision); - } - - @Override - public DecimalWriter decimal(String name) { - return writer.decimal(name); - } - - @Override - public Decimal256Writer decimal256() { - return this; - } @Override - public Decimal256Writer decimal256(String name, int scale, int precision) { - return writer.decimal256(name, scale, precision); - } - - @Override - public Decimal256Writer decimal256(String name) { - return writer.decimal256(name); + public ${minor.class}Writer ${lowerName}(String name) { + structName = name; + return writer.${lowerName}(name); } + @Override public StructWriter struct() { @@ -240,18 +220,6 @@ public void end() { inStruct = false; } - @Override - public void write(DecimalHolder holder) { - writer.write(holder); - writer.setPosition(writer.idx()+1); - } - - @Override - public void write(Decimal256Holder holder) { - writer.write(holder); - writer.setPosition(writer.idx()+1); - } - @Override public void writeNull() { if (!listStarted){ @@ -261,65 +229,53 @@ public void writeNull() { } } - public void writeDecimal(long start, ArrowBuf buffer, ArrowType arrowType) { - writer.writeDecimal(start, buffer, arrowType); - writer.setPosition(writer.idx()+1); - } - - public void writeDecimal(long start, ArrowBuf buffer) { - writer.writeDecimal(start, buffer); + <#list vv.types as type> + <#list type.minor as minor> + <#assign name = minor.class?cap_first /> + <#assign fields = minor.fields!type.fields /> + <#assign uncappedName = name?uncap_first/> + @Override + public void write${name}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, ) { + writer.write${name}(<#list fields as field>${field.name}<#if field_has_next>, ); writer.setPosition(writer.idx()+1); } - public void writeDecimal(BigDecimal value) { - writer.writeDecimal(value); + <#if is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> + @Override + public void write(${name}Holder holder) { + writer.write(holder); writer.setPosition(writer.idx()+1); } - public void writeBigEndianBytesToDecimal(byte[] value, ArrowType arrowType){ - writer.writeBigEndianBytesToDecimal(value, arrowType); - writer.setPosition(writer.idx() + 1); - } - - public void writeDecimal256(long start, ArrowBuf buffer, ArrowType arrowType) { - writer.writeDecimal256(start, buffer, arrowType); + <#elseif minor.class?starts_with("Decimal")> + public void write${name}(long start, ArrowBuf buffer, ArrowType arrowType) { + writer.write${name}(start, buffer, arrowType); writer.setPosition(writer.idx()+1); } - public void writeDecimal256(long start, ArrowBuf buffer) { - writer.writeDecimal256(start, buffer); + @Override + public void write(${name}Holder holder) { + writer.write(holder); writer.setPosition(writer.idx()+1); } - public void writeDecimal256(BigDecimal value) { - writer.writeDecimal256(value); + public void write${name}(BigDecimal value) { + writer.write${name}(value); writer.setPosition(writer.idx()+1); } - public void writeBigEndianBytesToDecimal256(byte[] value, ArrowType arrowType){ - writer.writeBigEndianBytesToDecimal256(value, arrowType); + public void writeBigEndianBytesTo${name}(byte[] value, ArrowType arrowType){ + writer.writeBigEndianBytesTo${name}(value, arrowType); writer.setPosition(writer.idx() + 1); } - - - <#list vv.types as type> - <#list type.minor as minor> - <#assign name = minor.class?cap_first /> - <#assign fields = minor.fields!type.fields /> - <#assign uncappedName = name?uncap_first/> - <#if !minor.typeParams?? > + <#else> @Override - public void write${name}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, ) { - writer.write${name}(<#list fields as field>${field.name}<#if field_has_next>, ); - writer.setPosition(writer.idx()+1); - } - public void write(${name}Holder holder) { writer.write${name}(<#list fields as field>holder.${field.name}<#if field_has_next>, ); writer.setPosition(writer.idx()+1); } + - } diff --git a/java/vector/src/main/codegen/templates/UnionReader.java b/java/vector/src/main/codegen/templates/UnionReader.java index 444ca9ca734c..56a6cc90b321 100644 --- a/java/vector/src/main/codegen/templates/UnionReader.java +++ b/java/vector/src/main/codegen/templates/UnionReader.java @@ -28,6 +28,11 @@ package org.apache.arrow.vector.complex.impl; <#include "/@includes/vv_imports.ftl" /> + +<#function is_timestamp_tz type> + <#return type?starts_with("TimeStamp") && type?ends_with("TZ")> + + /** * Source code generated using FreeMarker template ${.template_name} */ @@ -90,7 +95,7 @@ private FieldReader getReaderForIndex(int index) { <#list type.minor as minor> <#assign name = minor.class?cap_first /> <#assign uncappedName = name?uncap_first/> - <#if !minor.typeParams?? || minor.class?starts_with("Decimal")> + <#if !minor.typeParams?? || minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> case ${name?upper_case}: return (FieldReader) get${name}(); @@ -170,7 +175,7 @@ public int size() { <#assign friendlyType = (minor.friendlyType!minor.boxedType!type.boxedType) /> <#assign safeType=friendlyType /> <#if safeType=="byte[]"><#assign safeType="ByteArray" /> - <#if !minor.typeParams?? || minor.class?starts_with("Decimal") > + <#if !minor.typeParams?? || minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> private ${name}ReaderImpl ${uncappedName}Reader; diff --git a/java/vector/src/main/codegen/templates/UnionVector.java b/java/vector/src/main/codegen/templates/UnionVector.java index 48fa5281ea13..0446faab7a95 100644 --- a/java/vector/src/main/codegen/templates/UnionVector.java +++ b/java/vector/src/main/codegen/templates/UnionVector.java @@ -67,7 +67,9 @@ import static org.apache.arrow.memory.util.LargeMemoryUtil.checkedCastToInt; import static org.apache.arrow.memory.util.LargeMemoryUtil.capAtMaxInt; - +<#function is_timestamp_tz type> + <#return type?starts_with("TimeStamp") && type?ends_with("TZ")> + /* * This class is generated using freemarker and the ${.template_name} template. @@ -269,18 +271,24 @@ public StructVector getStruct() { <#assign fields = minor.fields!type.fields /> <#assign uncappedName = name?uncap_first/> <#assign lowerCaseName = name?lower_case/> - <#if !minor.typeParams?? || minor.class?starts_with("Decimal") > + <#if !minor.typeParams?? || minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> private ${name}Vector ${uncappedName}Vector; - public ${name}Vector get${name}Vector(<#if minor.class?starts_with("Decimal")> ArrowType arrowType) { - return get${name}Vector(null<#if minor.class?starts_with("Decimal")>, arrowType); + <#if minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> + public ${name}Vector get${name}Vector() { + if (${uncappedName}Vector == null) { + throw new IllegalArgumentException("No ${name} present. Provide ArrowType argument to create a new vector"); + } + return ${uncappedName}Vector; } - - public ${name}Vector get${name}Vector(String name<#if minor.class?starts_with("Decimal")>, ArrowType arrowType) { + public ${name}Vector get${name}Vector(ArrowType arrowType) { + return get${name}Vector(null, arrowType); + } + public ${name}Vector get${name}Vector(String name, ArrowType arrowType) { if (${uncappedName}Vector == null) { int vectorCount = internalStruct.size(); - ${uncappedName}Vector = addOrGet(name, MinorType.${name?upper_case},<#if minor.class?starts_with("Decimal")> arrowType, ${name}Vector.class); + ${uncappedName}Vector = addOrGet(name, MinorType.${name?upper_case}, arrowType, ${name}Vector.class); if (internalStruct.size() > vectorCount) { ${uncappedName}Vector.allocateNew(); if (callBack != null) { @@ -290,10 +298,21 @@ public StructVector getStruct() { } return ${uncappedName}Vector; } - <#if minor.class?starts_with("Decimal")> + <#else> public ${name}Vector get${name}Vector() { + return get${name}Vector(null); + } + + public ${name}Vector get${name}Vector(String name) { if (${uncappedName}Vector == null) { - throw new IllegalArgumentException("No ${uncappedName} present. Provide ArrowType argument to create a new vector"); + int vectorCount = internalStruct.size(); + ${uncappedName}Vector = addOrGet(name, MinorType.${name?upper_case}, ${name}Vector.class); + if (internalStruct.size() > vectorCount) { + ${uncappedName}Vector.allocateNew(); + if (callBack != null) { + callBack.doWork(); + } + } } return ${uncappedName}Vector; } @@ -658,9 +677,9 @@ public ValueVector getVectorByType(int typeId, ArrowType arrowType) { <#assign name = minor.class?cap_first /> <#assign fields = minor.fields!type.fields /> <#assign uncappedName = name?uncap_first/> - <#if !minor.typeParams?? || minor.class?starts_with("Decimal") > + <#if !minor.typeParams?? || minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> case ${name?upper_case}: - return get${name}Vector(name<#if minor.class?starts_with("Decimal")>, arrowType); + return get${name}Vector(name<#if minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary">, arrowType); @@ -745,11 +764,15 @@ public void setSafe(int index, UnionHolder holder, ArrowType arrowType) { <#assign name = minor.class?cap_first /> <#assign fields = minor.fields!type.fields /> <#assign uncappedName = name?uncap_first/> - <#if !minor.typeParams?? || minor.class?starts_with("Decimal") > + <#if !minor.typeParams?? || minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> case ${name?upper_case}: Nullable${name}Holder ${uncappedName}Holder = new Nullable${name}Holder(); reader.read(${uncappedName}Holder); - setSafe(index, ${uncappedName}Holder<#if minor.class?starts_with("Decimal")>, arrowType); + <#if minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> + setSafe(index, ${uncappedName}Holder, arrowType); + <#else> + setSafe(index, ${uncappedName}Holder); + break; @@ -766,17 +789,24 @@ public void setSafe(int index, UnionHolder holder, ArrowType arrowType) { throw new UnsupportedOperationException(); } } + <#list vv.types as type> <#list type.minor as minor> <#assign name = minor.class?cap_first /> <#assign fields = minor.fields!type.fields /> <#assign uncappedName = name?uncap_first/> - <#if !minor.typeParams?? || minor.class?starts_with("Decimal") > - public void setSafe(int index, Nullable${name}Holder holder<#if minor.class?starts_with("Decimal")>, ArrowType arrowType) { + <#if !minor.typeParams?? || minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> + <#if minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> + public void setSafe(int index, Nullable${name}Holder holder, ArrowType arrowType) { setType(index, MinorType.${name?upper_case}); - get${name}Vector(null<#if minor.class?starts_with("Decimal")>, arrowType).setSafe(index, holder); + get${name}Vector(null, arrowType).setSafe(index, holder); } - + <#else> + public void setSafe(int index, Nullable${name}Holder holder) { + setType(index, MinorType.${name?upper_case}); + get${name}Vector(null).setSafe(index, holder); + } + diff --git a/java/vector/src/main/codegen/templates/UnionWriter.java b/java/vector/src/main/codegen/templates/UnionWriter.java index fc4fd7dd798e..4efd1026cac4 100644 --- a/java/vector/src/main/codegen/templates/UnionWriter.java +++ b/java/vector/src/main/codegen/templates/UnionWriter.java @@ -31,6 +31,11 @@ import org.apache.arrow.vector.complex.writer.BaseWriter; import org.apache.arrow.vector.types.Types.MinorType; +<#function is_timestamp_tz type> + <#return type?starts_with("TimeStamp") && type?ends_with("TZ")> + + + /* * This class is generated using freemarker and the ${.template_name} template. */ @@ -183,9 +188,13 @@ BaseWriter getWriter(MinorType minorType, ArrowType arrowType) { <#assign name = minor.class?cap_first /> <#assign fields = minor.fields!type.fields /> <#assign uncappedName = name?uncap_first/> - <#if !minor.typeParams?? || minor.class?starts_with("Decimal")> + <#if !minor.typeParams?? || minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> case ${name?upper_case}: - return get${name}Writer(<#if minor.class?starts_with("Decimal") >arrowType); + <#if minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> + return get${name}Writer(arrowType); + <#else> + return get${name}Writer(); + @@ -199,36 +208,86 @@ BaseWriter getWriter(MinorType minorType, ArrowType arrowType) { <#assign fields = minor.fields!type.fields /> <#assign uncappedName = name?uncap_first/> <#assign friendlyType = (minor.friendlyType!minor.boxedType!type.boxedType) /> - <#if !minor.typeParams?? || minor.class?starts_with("Decimal") > + <#if !minor.typeParams?? || minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> private ${name}Writer ${name?uncap_first}Writer; - private ${name}Writer get${name}Writer(<#if minor.class?starts_with("Decimal")>ArrowType arrowType) { + <#if minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> + private ${name}Writer get${name}Writer(ArrowType arrowType) { if (${uncappedName}Writer == null) { - ${uncappedName}Writer = new ${name}WriterImpl(data.get${name}Vector(<#if minor.class?starts_with("Decimal")>arrowType)); + ${uncappedName}Writer = new ${name}WriterImpl(data.get${name}Vector(arrowType)); ${uncappedName}Writer.setPosition(idx()); writers.add(${uncappedName}Writer); } return ${uncappedName}Writer; } - public ${name}Writer as${name}(<#if minor.class?starts_with("Decimal")>ArrowType arrowType) { + public ${name}Writer as${name}(ArrowType arrowType) { data.setType(idx(), MinorType.${name?upper_case}); - return get${name}Writer(<#if minor.class?starts_with("Decimal")>arrowType); + return get${name}Writer(arrowType); + } + <#else> + private ${name}Writer get${name}Writer() { + if (${uncappedName}Writer == null) { + ${uncappedName}Writer = new ${name}WriterImpl(data.get${name}Vector()); + ${uncappedName}Writer.setPosition(idx()); + writers.add(${uncappedName}Writer); + } + return ${uncappedName}Writer; } + public ${name}Writer as${name}() { + data.setType(idx(), MinorType.${name?upper_case}); + return get${name}Writer(); + } + + @Override public void write(${name}Holder holder) { data.setType(idx(), MinorType.${name?upper_case}); - <#if minor.class?starts_with("Decimal")>ArrowType arrowType = new ArrowType.Decimal(holder.precision, holder.scale, ${name}Holder.WIDTH * 8); - get${name}Writer(<#if minor.class?starts_with("Decimal")>arrowType).setPosition(idx()); - get${name}Writer(<#if minor.class?starts_with("Decimal")>arrowType).write${name}(<#list fields as field>holder.${field.name}<#if field_has_next>, <#if minor.class?starts_with("Decimal")>, arrowType); + <#if minor.class?starts_with("Decimal")> + ArrowType arrowType = new ArrowType.Decimal(holder.precision, holder.scale, ${name}Holder.WIDTH * 8); + get${name}Writer(arrowType).setPosition(idx()); + get${name}Writer(arrowType).write${name}(<#list fields as field>holder.${field.name}<#if field_has_next>, , arrowType); + <#elseif is_timestamp_tz(minor.class)> + ArrowType.Timestamp arrowTypeWithoutTz = (ArrowType.Timestamp) MinorType.${name?upper_case?remove_ending("TZ")}.getType(); + ArrowType arrowType = new ArrowType.Timestamp(arrowTypeWithoutTz.getUnit(), holder.timezone); + get${name}Writer(arrowType).setPosition(idx()); + get${name}Writer(arrowType).write(holder); + <#elseif minor.class == "Duration"> + ArrowType arrowType = new ArrowType.Duration(holder.unit); + get${name}Writer(arrowType).setPosition(idx()); + get${name}Writer(arrowType).write(holder); + <#elseif minor.class == "FixedSizeBinary"> + ArrowType arrowType = new ArrowType.FixedSizeBinary(holder.byteWidth); + get${name}Writer(arrowType).setPosition(idx()); + get${name}Writer(arrowType).write(holder); + <#else> + get${name}Writer().setPosition(idx()); + get${name}Writer().write${name}(<#list fields as field>holder.${field.name}<#if field_has_next>, ); + } public void write${minor.class}(<#list fields as field>${field.type} ${field.name}<#if field_has_next>, <#if minor.class?starts_with("Decimal")>, ArrowType arrowType) { data.setType(idx(), MinorType.${name?upper_case}); - get${name}Writer(<#if minor.class?starts_with("Decimal")>arrowType).setPosition(idx()); - get${name}Writer(<#if minor.class?starts_with("Decimal")>arrowType).write${name}(<#list fields as field>${field.name}<#if field_has_next>, <#if minor.class?starts_with("Decimal")>, arrowType); + <#if minor.class?starts_with("Decimal")> + get${name}Writer(arrowType).setPosition(idx()); + get${name}Writer(arrowType).write${name}(<#list fields as field>${field.name}<#if field_has_next>, , arrowType); + <#elseif is_timestamp_tz(minor.class)> + ArrowType.Timestamp arrowTypeWithoutTz = (ArrowType.Timestamp) MinorType.${name?upper_case?remove_ending("TZ")}.getType(); + ArrowType arrowType = new ArrowType.Timestamp(arrowTypeWithoutTz.getUnit(), "UTC"); + get${name}Writer(arrowType).setPosition(idx()); + get${name}Writer(arrowType).write${name}(<#list fields as field>${field.name}<#if field_has_next>, ); + <#elseif minor.class == "Duration" || minor.class == "FixedSizeBinary"> + // This is expected to throw. There's nothing more that we can do here since we can't infer any + // sort of default unit for the Duration or a default width for the FixedSizeBinary types. + ArrowType arrowType = MinorType.${name?upper_case}.getType(); + get${name}Writer(arrowType).setPosition(idx()); + get${name}Writer(arrowType).write${name}(<#list fields as field>${field.name}<#if field_has_next>, ); + <#else> + get${name}Writer().setPosition(idx()); + get${name}Writer().write${name}(<#list fields as field>${field.name}<#if field_has_next>, ); + } <#if minor.class?starts_with("Decimal")> public void write${name}(${friendlyType} value) { @@ -312,7 +371,7 @@ public MapWriter map(String name, boolean keysSorted) { <#if lowerName == "int" ><#assign lowerName = "integer" /> <#assign upperName = minor.class?upper_case /> <#assign capName = minor.class?cap_first /> - <#if !minor.typeParams?? || minor.class?starts_with("Decimal") > + <#if !minor.typeParams?? || minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> @Override public ${capName}Writer ${lowerName}(String name) { data.setType(idx(), MinorType.STRUCT); @@ -327,7 +386,7 @@ public MapWriter map(String name, boolean keysSorted) { return getListWriter().${lowerName}(); } - <#if minor.class?starts_with("Decimal")> + <#if minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> @Override public ${capName}Writer ${lowerName}(String name<#list minor.typeParams as typeParam>, ${typeParam.type} ${typeParam.name}) { data.setType(idx(), MinorType.STRUCT); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/DurationVector.java b/java/vector/src/main/java/org/apache/arrow/vector/DurationVector.java index 9671b34e0027..7af21a8ecdc9 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/DurationVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/DurationVector.java @@ -141,6 +141,7 @@ public void get(int index, NullableDurationHolder holder) { } holder.isSet = 1; holder.value = get(valueBuffer, index); + holder.unit = this.unit; } /** @@ -241,6 +242,9 @@ public void set(int index, long value) { public void set(int index, NullableDurationHolder holder) throws IllegalArgumentException { if (holder.isSet < 0) { throw new IllegalArgumentException(); + } else if (!this.unit.equals(holder.unit)) { + throw new IllegalArgumentException( + String.format("holder.unit: %s not equal to vector unit: %s", holder.unit, this.unit)); } else if (holder.isSet > 0) { set(index, holder.value); } else { @@ -255,6 +259,10 @@ public void set(int index, NullableDurationHolder holder) throws IllegalArgument * @param holder data holder for value of element */ public void set(int index, DurationHolder holder) { + if (!this.unit.equals(holder.unit)) { + throw new IllegalArgumentException( + String.format("holder.unit: %s not equal to vector unit: %s", holder.unit, this.unit)); + } set(index, holder.value); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/FixedSizeBinaryVector.java b/java/vector/src/main/java/org/apache/arrow/vector/FixedSizeBinaryVector.java index e1847e4bb944..f9ea37e4c376 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/FixedSizeBinaryVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/FixedSizeBinaryVector.java @@ -138,6 +138,7 @@ public void get(int index, NullableFixedSizeBinaryHolder holder) { } holder.isSet = 1; holder.buffer = valueBuffer.slice((long) index * byteWidth, byteWidth); + holder.byteWidth = byteWidth; } /** @@ -257,7 +258,10 @@ public void setSafe(int index, int isSet, ArrowBuf buffer) { * @param holder holder that carries data buffer. */ public void set(int index, FixedSizeBinaryHolder holder) { - assert holder.byteWidth == byteWidth; + if (this.byteWidth != holder.byteWidth) { + throw new IllegalArgumentException( + String.format("holder.byteWidth: %d not equal to vector byteWidth: %d", holder.byteWidth, this.byteWidth)); + } set(index, holder.buffer); } @@ -282,9 +286,11 @@ public void setSafe(int index, FixedSizeBinaryHolder holder) { * @param holder holder that carries data buffer. */ public void set(int index, NullableFixedSizeBinaryHolder holder) { - assert holder.byteWidth == byteWidth; if (holder.isSet < 0) { throw new IllegalArgumentException("holder has a negative isSet value"); + } else if (this.byteWidth != holder.byteWidth) { + throw new IllegalArgumentException( + String.format("holder.byteWidth: %d not equal to vector byteWidth: %d", holder.byteWidth, this.byteWidth)); } else if (holder.isSet > 0) { set(index, holder.buffer); } else { diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMicroTZVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMicroTZVector.java index d08b3523067b..e083392ffe56 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMicroTZVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMicroTZVector.java @@ -132,6 +132,7 @@ public void get(int index, NullableTimeStampMicroTZHolder holder) { } holder.isSet = 1; holder.value = valueBuffer.getLong((long) index * TYPE_WIDTH); + holder.timezone = timeZone; } /** @@ -167,6 +168,9 @@ public Long getObject(int index) { public void set(int index, NullableTimeStampMicroTZHolder holder) throws IllegalArgumentException { if (holder.isSet < 0) { throw new IllegalArgumentException(); + } else if (!this.timeZone.equals(holder.timezone)) { + throw new IllegalArgumentException( + String.format("holder.timezone: %s not equal to vector timezone: %s", holder.timezone, this.timeZone)); } else if (holder.isSet > 0) { BitVectorHelper.setBit(validityBuffer, index); setValue(index, holder.value); @@ -182,6 +186,10 @@ public void set(int index, NullableTimeStampMicroTZHolder holder) throws Illegal * @param holder data holder for value of element */ public void set(int index, TimeStampMicroTZHolder holder) { + if (!this.timeZone.equals(holder.timezone)) { + throw new IllegalArgumentException( + String.format("holder.timezone: %s not equal to vector timezone: %s", holder.timezone, this.timeZone)); + } BitVectorHelper.setBit(validityBuffer, index); setValue(index, holder.value); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMilliTZVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMilliTZVector.java index 1151d064e255..d01a43aa1b63 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMilliTZVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampMilliTZVector.java @@ -132,6 +132,7 @@ public void get(int index, NullableTimeStampMilliTZHolder holder) { } holder.isSet = 1; holder.value = valueBuffer.getLong((long) index * TYPE_WIDTH); + holder.timezone = timeZone; } /** @@ -167,6 +168,9 @@ public Long getObject(int index) { public void set(int index, NullableTimeStampMilliTZHolder holder) throws IllegalArgumentException { if (holder.isSet < 0) { throw new IllegalArgumentException(); + } else if (!this.timeZone.equals(holder.timezone)) { + throw new IllegalArgumentException( + String.format("holder.timezone: %s not equal to vector timezone: %s", holder.timezone, this.timeZone)); } else if (holder.isSet > 0) { BitVectorHelper.setBit(validityBuffer, index); setValue(index, holder.value); @@ -182,6 +186,10 @@ public void set(int index, NullableTimeStampMilliTZHolder holder) throws Illegal * @param holder data holder for value of element */ public void set(int index, TimeStampMilliTZHolder holder) { + if (!this.timeZone.equals(holder.timezone)) { + throw new IllegalArgumentException( + String.format("holder.timezone: %s not equal to vector timezone: %s", holder.timezone, this.timeZone)); + } BitVectorHelper.setBit(validityBuffer, index); setValue(index, holder.value); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampNanoTZVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampNanoTZVector.java index b19b437781fd..2a51babda16f 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampNanoTZVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampNanoTZVector.java @@ -132,6 +132,7 @@ public void get(int index, NullableTimeStampNanoTZHolder holder) { } holder.isSet = 1; holder.value = valueBuffer.getLong((long) index * TYPE_WIDTH); + holder.timezone = timeZone; } /** @@ -167,6 +168,9 @@ public Long getObject(int index) { public void set(int index, NullableTimeStampNanoTZHolder holder) throws IllegalArgumentException { if (holder.isSet < 0) { throw new IllegalArgumentException(); + } else if (!this.timeZone.equals(holder.timezone)) { + throw new IllegalArgumentException( + String.format("holder.timezone: %s not equal to vector timezone: %s", holder.timezone, this.timeZone)); } else if (holder.isSet > 0) { BitVectorHelper.setBit(validityBuffer, index); setValue(index, holder.value); @@ -182,6 +186,10 @@ public void set(int index, NullableTimeStampNanoTZHolder holder) throws IllegalA * @param holder data holder for value of element */ public void set(int index, TimeStampNanoTZHolder holder) { + if (!this.timeZone.equals(holder.timezone)) { + throw new IllegalArgumentException( + String.format("holder.timezone: %s not equal to vector timezone: %s", holder.timezone, this.timeZone)); + } BitVectorHelper.setBit(validityBuffer, index); setValue(index, holder.value); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampSecTZVector.java b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampSecTZVector.java index 1ffdb55c7a56..47e796e9951f 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/TimeStampSecTZVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/TimeStampSecTZVector.java @@ -132,6 +132,7 @@ public void get(int index, NullableTimeStampSecTZHolder holder) { } holder.isSet = 1; holder.value = valueBuffer.getLong((long) index * TYPE_WIDTH); + holder.timezone = timeZone; } /** @@ -167,6 +168,9 @@ public Long getObject(int index) { public void set(int index, NullableTimeStampSecTZHolder holder) throws IllegalArgumentException { if (holder.isSet < 0) { throw new IllegalArgumentException(); + } else if (!this.timeZone.equals(holder.timezone)) { + throw new IllegalArgumentException( + String.format("holder.timezone: %s not equal to vector timezone: %s", holder.timezone, this.timeZone)); } else if (holder.isSet > 0) { BitVectorHelper.setBit(validityBuffer, index); setValue(index, holder.value); @@ -182,6 +186,10 @@ public void set(int index, NullableTimeStampSecTZHolder holder) throws IllegalAr * @param holder data holder for value of element */ public void set(int index, TimeStampSecTZHolder holder) { + if (!this.timeZone.equals(holder.timezone)) { + throw new IllegalArgumentException( + String.format("holder.timezone: %s not equal to vector timezone: %s", holder.timezone, this.timeZone)); + } BitVectorHelper.setBit(validityBuffer, index); setValue(index, holder.value); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java index 06b064fdaac5..d99efceae3ec 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/impl/PromotableWriter.java @@ -247,10 +247,18 @@ public void setPosition(int index) { } } + private boolean requiresArrowType(MinorType type) { + return type == MinorType.DECIMAL || + type == MinorType.MAP || + type == MinorType.DURATION || + type == MinorType.FIXEDSIZEBINARY || + (type.name().startsWith("TIMESTAMP") && type.name().endsWith("TZ")); + } + @Override protected FieldWriter getWriter(MinorType type, ArrowType arrowType) { if (state == State.UNION) { - if (type == MinorType.DECIMAL || type == MinorType.MAP) { + if (requiresArrowType(type)) { ((UnionWriter) writer).getWriter(type, arrowType); } else { ((UnionWriter) writer).getWriter(type); @@ -277,7 +285,7 @@ protected FieldWriter getWriter(MinorType type, ArrowType arrowType) { writer.setPosition(position); } else if (type != this.type) { promoteToUnion(); - if (type == MinorType.DECIMAL || type == MinorType.MAP) { + if (requiresArrowType(type)) { ((UnionWriter) writer).getWriter(type, arrowType); } else { ((UnionWriter) writer).getWriter(type); diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestFixedSizeBinaryVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestFixedSizeBinaryVector.java index 363821e98397..e8f764a21e96 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestFixedSizeBinaryVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestFixedSizeBinaryVector.java @@ -207,25 +207,25 @@ public void testSetWithInvalidInput() throws Exception { try { vector.set(0, smallValue); failWithException(errorMsg); - } catch (AssertionError ignore) { + } catch (AssertionError | IllegalArgumentException ignore) { } try { vector.set(0, smallHolder); failWithException(errorMsg); - } catch (AssertionError ignore) { + } catch (AssertionError | IllegalArgumentException ignore) { } try { vector.set(0, smallNullableHolder); failWithException(errorMsg); - } catch (AssertionError ignore) { + } catch (AssertionError | IllegalArgumentException ignore) { } try { vector.set(0, smallBuf); failWithException(errorMsg); - } catch (AssertionError ignore) { + } catch (AssertionError | IllegalArgumentException ignore) { } // test large inputs, byteWidth matches but value or buffer is bigger than byteWidth @@ -243,25 +243,25 @@ public void setSetSafeWithInvalidInput() throws Exception { try { vector.setSafe(0, smallValue); failWithException(errorMsg); - } catch (AssertionError ignore) { + } catch (AssertionError | IllegalArgumentException ignore) { } try { vector.setSafe(0, smallHolder); failWithException(errorMsg); - } catch (AssertionError ignore) { + } catch (AssertionError | IllegalArgumentException ignore) { } try { vector.setSafe(0, smallNullableHolder); failWithException(errorMsg); - } catch (AssertionError ignore) { + } catch (AssertionError | IllegalArgumentException ignore) { } try { vector.setSafe(0, smallBuf); failWithException(errorMsg); - } catch (AssertionError ignore) { + } catch (AssertionError | IllegalArgumentException ignore) { } // test large inputs, byteWidth matches but value or buffer is bigger than byteWidth diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestListVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestListVector.java index ffeedf04d033..f0f19058eef2 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestListVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestListVector.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; import java.util.ArrayList; import java.util.Arrays; @@ -28,10 +29,15 @@ import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.AutoCloseables; import org.apache.arrow.vector.complex.BaseRepeatedValueVector; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.impl.UnionListWriter; import org.apache.arrow.vector.complex.reader.FieldReader; +import org.apache.arrow.vector.holders.DurationHolder; +import org.apache.arrow.vector.holders.FixedSizeBinaryHolder; +import org.apache.arrow.vector.holders.TimeStampMilliTZHolder; +import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.Types.MinorType; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; @@ -898,6 +904,144 @@ public void testWriterGetField() { } } + @Test + public void testWriterGetTimestampMilliTZField() { + try (final ListVector vector = ListVector.empty("list", allocator)) { + org.apache.arrow.vector.complex.writer.FieldWriter writer = vector.getWriter(); + writer.allocate(); + + writer.startList(); + writer.timeStampMilliTZ().writeTimeStampMilliTZ(1000L); + writer.timeStampMilliTZ().writeTimeStampMilliTZ(2000L); + writer.endList(); + vector.setValueCount(1); + + Field expectedDataField = new Field(BaseRepeatedValueVector.DATA_VECTOR_NAME, + FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC")), null); + Field expectedField = new Field(vector.getName(), FieldType.nullable(ArrowType.List.INSTANCE), + Arrays.asList(expectedDataField)); + + assertEquals(expectedField, writer.getField()); + } + } + + @Test + public void testWriterUsingHolderGetTimestampMilliTZField() { + try (final ListVector vector = ListVector.empty("list", allocator)) { + org.apache.arrow.vector.complex.writer.FieldWriter writer = vector.getWriter(); + writer.allocate(); + + TimeStampMilliTZHolder holder = new TimeStampMilliTZHolder(); + holder.timezone = "SomeFakeTimeZone"; + writer.startList(); + holder.value = 12341234L; + writer.timeStampMilliTZ().write(holder); + holder.value = 55555L; + writer.timeStampMilliTZ().write(holder); + + // Writing with a different timezone should throw + holder.timezone = "AsdfTimeZone"; + holder.value = 77777; + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, + () -> writer.timeStampMilliTZ().write(holder)); + assertEquals("holder.timezone: AsdfTimeZone not equal to vector timezone: SomeFakeTimeZone", ex.getMessage()); + + writer.endList(); + vector.setValueCount(1); + + Field expectedDataField = new Field(BaseRepeatedValueVector.DATA_VECTOR_NAME, + FieldType.nullable(new ArrowType.Timestamp(TimeUnit.MILLISECOND, "SomeFakeTimeZone")), null); + Field expectedField = new Field(vector.getName(), FieldType.nullable(ArrowType.List.INSTANCE), + Arrays.asList(expectedDataField)); + + assertEquals(expectedField, writer.getField()); + } + } + + @Test + public void testWriterGetDurationField() { + try (final ListVector vector = ListVector.empty("list", allocator)) { + org.apache.arrow.vector.complex.writer.FieldWriter writer = vector.getWriter(); + writer.allocate(); + + DurationHolder durationHolder = new DurationHolder(); + durationHolder.unit = TimeUnit.MILLISECOND; + + writer.startList(); + durationHolder.value = 812374L; + writer.duration().write(durationHolder); + durationHolder.value = 143451L; + writer.duration().write(durationHolder); + + // Writing with a different unit should throw + durationHolder.unit = TimeUnit.SECOND; + durationHolder.value = 8888888; + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, + () -> writer.duration().write(durationHolder)); + assertEquals("holder.unit: SECOND not equal to vector unit: MILLISECOND", ex.getMessage()); + + writer.endList(); + vector.setValueCount(1); + + Field expectedDataField = new Field(BaseRepeatedValueVector.DATA_VECTOR_NAME, + FieldType.nullable(new ArrowType.Duration(TimeUnit.MILLISECOND)), null); + Field expectedField = new Field(vector.getName(), FieldType.nullable(ArrowType.List.INSTANCE), + Arrays.asList(expectedDataField)); + + assertEquals(expectedField, writer.getField()); + } + } + + @Test + public void testWriterGetFixedSizeBinaryField() throws Exception { + // Adapted from: TestComplexWriter.java:fixedSizeBinaryWriters + // test values + int numValues = 10; + int byteWidth = 9; + byte[][] values = new byte[numValues][byteWidth]; + for (int i = 0; i < numValues; i++) { + for (int j = 0; j < byteWidth; j++) { + values[i][j] = ((byte) i); + } + } + ArrowBuf[] bufs = new ArrowBuf[numValues]; + for (int i = 0; i < numValues; i++) { + bufs[i] = allocator.buffer(byteWidth); + bufs[i].setBytes(0, values[i]); + } + + try (final ListVector vector = ListVector.empty("list", allocator)) { + org.apache.arrow.vector.complex.writer.FieldWriter writer = vector.getWriter(); + writer.allocate(); + + FixedSizeBinaryHolder binHolder = new FixedSizeBinaryHolder(); + binHolder.byteWidth = byteWidth; + writer.startList(); + for (int i = 0; i < numValues; i++) { + binHolder.buffer = bufs[i]; + writer.fixedSizeBinary().write(binHolder); + } + + // Writing with a different byteWidth should throw + // Note just reusing the last buffer value since that won't matter here anyway + binHolder.byteWidth = 3; + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, + () -> writer.fixedSizeBinary().write(binHolder)); + assertEquals("holder.byteWidth: 3 not equal to vector byteWidth: 9", ex.getMessage()); + + writer.endList(); + vector.setValueCount(1); + + Field expectedDataField = new Field(BaseRepeatedValueVector.DATA_VECTOR_NAME, + FieldType.nullable(new ArrowType.FixedSizeBinary(byteWidth)), null); + Field expectedField = new Field(vector.getName(), FieldType.nullable(ArrowType.List.INSTANCE), + Arrays.asList(expectedDataField)); + + assertEquals(expectedField, writer.getField()); + } + AutoCloseables.close(bufs); + } + @Test public void testClose() throws Exception { try (final ListVector vector = ListVector.empty("list", allocator)) { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestUnionVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestUnionVector.java index f04998915b64..b53171a59768 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestUnionVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestUnionVector.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; +import static org.junit.jupiter.api.Assertions.assertThrows; import java.util.ArrayList; import java.util.HashMap; @@ -497,6 +498,28 @@ public void testSetGetNull() { } } + @Test + public void testCreateNewVectorWithoutTypeExceptionThrown() { + try (UnionVector vector = + new UnionVector(EMPTY_SCHEMA_PATH, allocator, /* field type */ null, /* call-back */ null)) { + IllegalArgumentException e1 = assertThrows(IllegalArgumentException.class, + () -> vector.getTimeStampMilliTZVector()); + assertEquals("No TimeStampMilliTZ present. Provide ArrowType argument to create a new vector", e1.getMessage()); + + IllegalArgumentException e2 = assertThrows(IllegalArgumentException.class, + () -> vector.getDurationVector()); + assertEquals("No Duration present. Provide ArrowType argument to create a new vector", e2.getMessage()); + + IllegalArgumentException e3 = assertThrows(IllegalArgumentException.class, + () -> vector.getFixedSizeBinaryVector()); + assertEquals("No FixedSizeBinary present. Provide ArrowType argument to create a new vector", e3.getMessage()); + + IllegalArgumentException e4 = assertThrows(IllegalArgumentException.class, + () -> vector.getDecimalVector()); + assertEquals("No Decimal present. Provide ArrowType argument to create a new vector", e4.getMessage()); + } + } + private static NullableIntHolder newIntHolder(int value) { final NullableIntHolder holder = new NullableIntHolder(); holder.isSet = 1; diff --git a/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java index 9dce33122e88..1068f7c030eb 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java @@ -20,7 +20,12 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.DirtyRootAllocator; import org.apache.arrow.vector.complex.ListVector; @@ -28,7 +33,13 @@ import org.apache.arrow.vector.complex.StructVector; import org.apache.arrow.vector.complex.UnionVector; import org.apache.arrow.vector.complex.writer.BaseWriter.StructWriter; +import org.apache.arrow.vector.holders.DurationHolder; +import org.apache.arrow.vector.holders.FixedSizeBinaryHolder; +import org.apache.arrow.vector.holders.NullableTimeStampMilliTZHolder; +import org.apache.arrow.vector.holders.TimeStampMilliTZHolder; +import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; @@ -78,9 +89,35 @@ public void testPromoteToUnion() throws Exception { writer.setPosition(4); writer.integer("A").writeInt(100); + writer.setPosition(5); + writer.timeStampMilliTZ("A").writeTimeStampMilliTZ(123123); + + // Also try the holder version for timeStampMilliTZ + writer.setPosition(6); + TimeStampMilliTZHolder tsmtzHolder = new TimeStampMilliTZHolder(); + // This has to be UTC since the vector above was initialized using the non holder + // version that defaults to UTC. + tsmtzHolder.timezone = "UTC"; + tsmtzHolder.value = 12345L; + writer.timeStampMilliTZ("A").write(tsmtzHolder); + + writer.setPosition(7); + DurationHolder durationHolder = new DurationHolder(); + durationHolder.unit = TimeUnit.SECOND; + durationHolder.value = 444413; + writer.duration("A").write(durationHolder); + + writer.setPosition(8); + ArrowBuf buf = allocator.buffer(4); + buf.setInt(0, 18978); + FixedSizeBinaryHolder binHolder = new FixedSizeBinaryHolder(); + binHolder.byteWidth = 4; + binHolder.buffer = buf; + writer.fixedSizeBinary("A", 4).write(binHolder); + writer.end(); - container.setValueCount(5); + container.setValueCount(9); final UnionVector uv = v.getChild("A", UnionVector.class); @@ -98,6 +135,22 @@ public void testPromoteToUnion() throws Exception { assertFalse("4 shouldn't be null", uv.isNull(4)); assertEquals(100, uv.getObject(4)); + assertFalse("5 shouldn't be null", uv.isNull(5)); + assertEquals(123123L, uv.getObject(5)); + + assertFalse("6 shouldn't be null", uv.isNull(6)); + NullableTimeStampMilliTZHolder readBackHolder = new NullableTimeStampMilliTZHolder(); + uv.getTimeStampMilliTZVector().get(6, readBackHolder); + assertEquals(12345L, readBackHolder.value); + assertEquals("UTC", readBackHolder.timezone); + + assertFalse("7 shouldn't be null", uv.isNull(7)); + assertEquals(444413L, ((java.time.Duration) uv.getObject(7)).getSeconds()); + + assertFalse("8 shouldn't be null", uv.isNull(8)); + assertEquals(18978, + ByteBuffer.wrap(uv.getFixedSizeBinaryVector().get(8)).order(ByteOrder.nativeOrder()).getInt()); + container.clear(); container.allocateNew(); @@ -116,11 +169,13 @@ public void testPromoteToUnion() throws Exception { childField1.getName(), ArrowTypeID.Union, childField1.getType().getTypeID()); assertEquals("Child field should be decimal type: " + childField2.getName(), ArrowTypeID.Decimal, childField2.getType().getTypeID()); + + buf.close(); } } @Test - public void testNoPromoteToUnionWithNull() throws Exception { + public void testNoPromoteFloat4ToUnionWithNull() throws Exception { try (final NonNullableStructVector container = NonNullableStructVector.empty(EMPTY_SCHEMA_PATH, allocator); final StructVector v = container.addOrGetStruct("test"); @@ -136,7 +191,6 @@ public void testNoPromoteToUnionWithNull() throws Exception { FieldType childTypeOfListInContainer = container.getField().getChildren().get(0).getChildren().get(0) .getChildren().get(0).getFieldType(); - // create a listvector with same type as list in container to, say, hold a copy // this will be a nullvector ListVector lv = ListVector.empty("name", allocator); @@ -164,4 +218,178 @@ public void testNoPromoteToUnionWithNull() throws Exception { lv.close(); } } + + @Test + public void testNoPromoteTimeStampMilliTZToUnionWithNull() throws Exception { + + try (final NonNullableStructVector container = NonNullableStructVector.empty(EMPTY_SCHEMA_PATH, allocator); + final StructVector v = container.addOrGetStruct("test"); + final PromotableWriter writer = new PromotableWriter(v, container)) { + + container.allocateNew(); + + writer.start(); + writer.list("list").startList(); + writer.list("list").endList(); + writer.end(); + + FieldType childTypeOfListInContainer = container.getField().getChildren().get(0).getChildren().get(0) + .getChildren().get(0).getFieldType(); + + // create a listvector with same type as list in container to, say, hold a copy + // this will be a nullvector + ListVector lv = ListVector.empty("name", allocator); + lv.addOrGetVector(childTypeOfListInContainer); + assertEquals(childTypeOfListInContainer.getType(), Types.MinorType.NULL.getType()); + assertEquals(lv.getChildrenFromFields().get(0).getMinorType().getType(), Types.MinorType.NULL.getType()); + + writer.start(); + writer.list("list").startList(); + TimeStampMilliTZHolder holder = new TimeStampMilliTZHolder(); + holder.value = 12341234L; + holder.timezone = "FakeTimeZone"; + writer.list("list").timeStampMilliTZ().write(holder); + + // Test that we get an exception when the timezone doesn't match + holder.timezone = "SomeTimeZone"; + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, + () -> writer.list("list").timeStampMilliTZ().write(holder)); + assertEquals("holder.timezone: SomeTimeZone not equal to vector timezone: FakeTimeZone", ex.getMessage()); + + writer.list("list").endList(); + writer.end(); + + container.setValueCount(2); + + childTypeOfListInContainer = container.getField().getChildren().get(0).getChildren().get(0) + .getChildren().get(0).getFieldType(); + + // repeat but now the type in container has been changed from null to float + // we expect same behaviour from listvector + lv.addOrGetVector(childTypeOfListInContainer); + assertEquals(childTypeOfListInContainer.getType(), + new ArrowType.Timestamp(TimeUnit.MILLISECOND, "FakeTimeZone")); + assertEquals(lv.getChildrenFromFields().get(0).getField().getType(), + new ArrowType.Timestamp(TimeUnit.MILLISECOND, "FakeTimeZone")); + + lv.close(); + } + } + + @Test + public void testNoPromoteDurationToUnionWithNull() throws Exception { + + try (final NonNullableStructVector container = NonNullableStructVector.empty(EMPTY_SCHEMA_PATH, allocator); + final StructVector v = container.addOrGetStruct("test"); + final PromotableWriter writer = new PromotableWriter(v, container)) { + + container.allocateNew(); + + writer.start(); + writer.list("list").startList(); + writer.list("list").endList(); + writer.end(); + + FieldType childTypeOfListInContainer = container.getField().getChildren().get(0).getChildren().get(0) + .getChildren().get(0).getFieldType(); + + // create a listvector with same type as list in container to, say, hold a copy + // this will be a nullvector + ListVector lv = ListVector.empty("name", allocator); + lv.addOrGetVector(childTypeOfListInContainer); + assertEquals(childTypeOfListInContainer.getType(), Types.MinorType.NULL.getType()); + assertEquals(lv.getChildrenFromFields().get(0).getMinorType().getType(), Types.MinorType.NULL.getType()); + + writer.start(); + writer.list("list").startList(); + DurationHolder holder = new DurationHolder(); + holder.unit = TimeUnit.NANOSECOND; + holder.value = 567657L; + writer.list("list").duration().write(holder); + + // Test that we get an exception when the unit doesn't match + holder.unit = TimeUnit.MICROSECOND; + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, + () -> writer.list("list").duration().write(holder)); + assertEquals("holder.unit: MICROSECOND not equal to vector unit: NANOSECOND", ex.getMessage()); + + writer.list("list").endList(); + writer.end(); + + container.setValueCount(2); + + childTypeOfListInContainer = container.getField().getChildren().get(0).getChildren().get(0) + .getChildren().get(0).getFieldType(); + + // repeat but now the type in container has been changed from null to float + // we expect same behaviour from listvector + lv.addOrGetVector(childTypeOfListInContainer); + assertEquals(childTypeOfListInContainer.getType(), + new ArrowType.Duration(TimeUnit.NANOSECOND)); + assertEquals(lv.getChildrenFromFields().get(0).getField().getType(), + new ArrowType.Duration(TimeUnit.NANOSECOND)); + + lv.close(); + } + } + + @Test + public void testNoPromoteFixedSizeBinaryToUnionWithNull() throws Exception { + + try (final NonNullableStructVector container = NonNullableStructVector.empty(EMPTY_SCHEMA_PATH, allocator); + final StructVector v = container.addOrGetStruct("test"); + final PromotableWriter writer = new PromotableWriter(v, container)) { + + container.allocateNew(); + + writer.start(); + writer.list("list").startList(); + writer.list("list").endList(); + writer.end(); + + FieldType childTypeOfListInContainer = container.getField().getChildren().get(0).getChildren().get(0) + .getChildren().get(0).getFieldType(); + + // create a listvector with same type as list in container to, say, hold a copy + // this will be a nullvector + ListVector lv = ListVector.empty("name", allocator); + lv.addOrGetVector(childTypeOfListInContainer); + assertEquals(childTypeOfListInContainer.getType(), Types.MinorType.NULL.getType()); + assertEquals(lv.getChildrenFromFields().get(0).getMinorType().getType(), Types.MinorType.NULL.getType()); + + writer.start(); + writer.list("list").startList(); + ArrowBuf buf = allocator.buffer(4); + buf.setInt(0, 22222); + FixedSizeBinaryHolder holder = new FixedSizeBinaryHolder(); + holder.byteWidth = 4; + holder.buffer = buf; + writer.list("list").fixedSizeBinary().write(holder); + + // Test that we get an exception when the unit doesn't match + holder.byteWidth = 7; + IllegalArgumentException ex = assertThrows(IllegalArgumentException.class, + () -> writer.list("list").fixedSizeBinary().write(holder)); + assertEquals("holder.byteWidth: 7 not equal to vector byteWidth: 4", ex.getMessage()); + + writer.list("list").endList(); + writer.end(); + + container.setValueCount(2); + + childTypeOfListInContainer = container.getField().getChildren().get(0).getChildren().get(0) + .getChildren().get(0).getFieldType(); + + // repeat but now the type in container has been changed from null to float + // we expect same behaviour from listvector + lv.addOrGetVector(childTypeOfListInContainer); + assertEquals(childTypeOfListInContainer.getType(), + new ArrowType.FixedSizeBinary(4)); + assertEquals(lv.getChildrenFromFields().get(0).getField().getType(), + new ArrowType.FixedSizeBinary(4)); + + lv.close(); + buf.close(); + } + } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java index 55041496653a..9f7f66083ac6 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/complex/writer/TestComplexWriter.java @@ -21,6 +21,7 @@ import java.math.BigDecimal; import java.time.LocalDateTime; +import java.util.ArrayList; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -61,8 +62,15 @@ import org.apache.arrow.vector.complex.writer.BaseWriter.MapWriter; import org.apache.arrow.vector.complex.writer.BaseWriter.StructWriter; import org.apache.arrow.vector.holders.DecimalHolder; +import org.apache.arrow.vector.holders.DurationHolder; +import org.apache.arrow.vector.holders.FixedSizeBinaryHolder; import org.apache.arrow.vector.holders.IntHolder; +import org.apache.arrow.vector.holders.NullableDurationHolder; +import org.apache.arrow.vector.holders.NullableFixedSizeBinaryHolder; +import org.apache.arrow.vector.holders.NullableTimeStampMilliTZHolder; import org.apache.arrow.vector.holders.NullableTimeStampNanoTZHolder; +import org.apache.arrow.vector.holders.TimeStampMilliTZHolder; +import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID; import org.apache.arrow.vector.types.pojo.ArrowType.Int; @@ -354,6 +362,125 @@ public void listDecimalType() { } } + @Test + public void listTimeStampMilliTZType() { + try (ListVector listVector = ListVector.empty("list", allocator)) { + listVector.allocateNew(); + UnionListWriter listWriter = new UnionListWriter(listVector); + for (int i = 0; i < COUNT; i++) { + listWriter.startList(); + for (int j = 0; j < i % 7; j++) { + if (j % 2 == 0) { + listWriter.writeNull(); + } else { + TimeStampMilliTZHolder holder = new TimeStampMilliTZHolder(); + holder.timezone = "FakeTimeZone"; + holder.value = j; + listWriter.timeStampMilliTZ().write(holder); + } + } + listWriter.endList(); + } + listWriter.setValueCount(COUNT); + UnionListReader listReader = new UnionListReader(listVector); + for (int i = 0; i < COUNT; i++) { + listReader.setPosition(i); + for (int j = 0; j < i % 7; j++) { + listReader.next(); + if (j % 2 == 0) { + assertFalse("index is set: " + j, listReader.reader().isSet()); + } else { + NullableTimeStampMilliTZHolder actual = new NullableTimeStampMilliTZHolder(); + listReader.reader().read(actual); + assertEquals(j, actual.value); + assertEquals("FakeTimeZone", actual.timezone); + } + } + } + } + } + + @Test + public void listDurationType() { + try (ListVector listVector = ListVector.empty("list", allocator)) { + listVector.allocateNew(); + UnionListWriter listWriter = new UnionListWriter(listVector); + for (int i = 0; i < COUNT; i++) { + listWriter.startList(); + for (int j = 0; j < i % 7; j++) { + if (j % 2 == 0) { + listWriter.writeNull(); + } else { + DurationHolder holder = new DurationHolder(); + holder.unit = TimeUnit.MICROSECOND; + holder.value = j; + listWriter.duration().write(holder); + } + } + listWriter.endList(); + } + listWriter.setValueCount(COUNT); + UnionListReader listReader = new UnionListReader(listVector); + for (int i = 0; i < COUNT; i++) { + listReader.setPosition(i); + for (int j = 0; j < i % 7; j++) { + listReader.next(); + if (j % 2 == 0) { + assertFalse("index is set: " + j, listReader.reader().isSet()); + } else { + NullableDurationHolder actual = new NullableDurationHolder(); + listReader.reader().read(actual); + assertEquals(TimeUnit.MICROSECOND, actual.unit); + assertEquals(j, actual.value); + } + } + } + } + } + + @Test + public void listFixedSizeBinaryType() throws Exception { + List bufs = new ArrayList(); + try (ListVector listVector = ListVector.empty("list", allocator)) { + listVector.allocateNew(); + UnionListWriter listWriter = new UnionListWriter(listVector); + for (int i = 0; i < COUNT; i++) { + listWriter.startList(); + for (int j = 0; j < i % 7; j++) { + if (j % 2 == 0) { + listWriter.writeNull(); + } else { + ArrowBuf buf = allocator.buffer(4); + buf.setInt(0, j); + FixedSizeBinaryHolder holder = new FixedSizeBinaryHolder(); + holder.byteWidth = 4; + holder.buffer = buf; + listWriter.fixedSizeBinary().write(holder); + bufs.add(buf); + } + } + listWriter.endList(); + } + listWriter.setValueCount(COUNT); + UnionListReader listReader = new UnionListReader(listVector); + for (int i = 0; i < COUNT; i++) { + listReader.setPosition(i); + for (int j = 0; j < i % 7; j++) { + listReader.next(); + if (j % 2 == 0) { + assertFalse("index is set: " + j, listReader.reader().isSet()); + } else { + NullableFixedSizeBinaryHolder actual = new NullableFixedSizeBinaryHolder(); + listReader.reader().read(actual); + assertEquals(j, actual.buffer.getInt(0)); + assertEquals(4, actual.byteWidth); + } + } + } + } + AutoCloseables.close(bufs); + } + @Test public void listScalarTypeNullable() { try (ListVector listVector = ListVector.empty("list", allocator)) { @@ -605,14 +732,33 @@ private void checkListMap(ListVector listVector) { } @Test - public void simpleUnion() { + public void simpleUnion() throws Exception { + List bufs = new ArrayList(); UnionVector vector = new UnionVector("union", allocator, /* field type */ null, /* call-back */ null); UnionWriter unionWriter = new UnionWriter(vector); unionWriter.allocate(); for (int i = 0; i < COUNT; i++) { unionWriter.setPosition(i); - if (i % 2 == 0) { + if (i % 5 == 0) { unionWriter.writeInt(i); + } else if (i % 5 == 1) { + TimeStampMilliTZHolder holder = new TimeStampMilliTZHolder(); + holder.value = (long) i; + holder.timezone = "AsdfTimeZone"; + unionWriter.write(holder); + } else if (i % 5 == 2) { + DurationHolder holder = new DurationHolder(); + holder.value = (long) i; + holder.unit = TimeUnit.NANOSECOND; + unionWriter.write(holder); + } else if (i % 5 == 3) { + FixedSizeBinaryHolder holder = new FixedSizeBinaryHolder(); + ArrowBuf buf = allocator.buffer(4); + buf.setInt(0, i); + holder.byteWidth = 4; + holder.buffer = buf; + unionWriter.write(holder); + bufs.add(buf); } else { unionWriter.writeFloat4((float) i); } @@ -621,13 +767,29 @@ public void simpleUnion() { UnionReader unionReader = new UnionReader(vector); for (int i = 0; i < COUNT; i++) { unionReader.setPosition(i); - if (i % 2 == 0) { + if (i % 5 == 0) { Assert.assertEquals(i, i, unionReader.readInteger()); + } else if (i % 5 == 1) { + NullableTimeStampMilliTZHolder holder = new NullableTimeStampMilliTZHolder(); + unionReader.read(holder); + Assert.assertEquals(i, holder.value); + Assert.assertEquals("AsdfTimeZone", holder.timezone); + } else if (i % 5 == 2) { + NullableDurationHolder holder = new NullableDurationHolder(); + unionReader.read(holder); + Assert.assertEquals(i, holder.value); + Assert.assertEquals(TimeUnit.NANOSECOND, holder.unit); + } else if (i % 5 == 3) { + NullableFixedSizeBinaryHolder holder = new NullableFixedSizeBinaryHolder(); + unionReader.read(holder); + assertEquals(i, holder.buffer.getInt(0)); + assertEquals(4, holder.byteWidth); } else { Assert.assertEquals((float) i, unionReader.readFloat(), 1e-12); } } vector.close(); + AutoCloseables.close(bufs); } @Test