diff --git a/cpp/src/arrow/ipc/json-internal.cc b/cpp/src/arrow/ipc/json-internal.cc index bc2b0d18e72..d82974bc221 100644 --- a/cpp/src/arrow/ipc/json-internal.cc +++ b/cpp/src/arrow/ipc/json-internal.cc @@ -269,10 +269,10 @@ class SchemaWriter { writer_->Key("mode"); switch (type.mode()) { case UnionMode::SPARSE: - writer_->String("SPARSE"); + writer_->String("Sparse"); break; case UnionMode::DENSE: - writer_->String("DENSE"); + writer_->String("Dense"); break; } @@ -569,7 +569,7 @@ class ArrayWriter { WriteValidityField(array); const auto& type = static_cast(*array.type()); - WriteIntegerField("TYPE_ID", array.raw_type_ids(), array.length()); + WriteIntegerField("TYPE", array.raw_type_ids(), array.length()); if (type.mode() == UnionMode::DENSE) { WriteIntegerField("OFFSET", array.raw_value_offsets(), array.length()); } @@ -763,9 +763,9 @@ static Status GetUnion(const RjObject& json_type, std::string mode_str = it_mode->value.GetString(); UnionMode mode; - if (mode_str == "SPARSE") { + if (mode_str == "Sparse") { mode = UnionMode::SPARSE; - } else if (mode_str == "DENSE") { + } else if (mode_str == "Dense") { mode = UnionMode::DENSE; } else { std::stringstream ss; @@ -774,13 +774,25 @@ static Status GetUnion(const RjObject& json_type, } const auto& it_type_codes = json_type.FindMember("typeIds"); - RETURN_NOT_ARRAY("typeIds", it_type_codes, json_type); std::vector type_codes; - const auto& id_array = it_type_codes->value.GetArray(); - for (const rj::Value& val : id_array) { - DCHECK(val.IsUint()); - type_codes.push_back(static_cast(val.GetUint())); + if (it_type_codes == json_type.MemberEnd()) { + for (uint8_t code = 0; code < static_cast(children.size()); ++code) { + type_codes.push_back(code); + } + } else { + RETURN_NOT_ARRAY("typeIds", it_type_codes, json_type); + const auto& id_array = it_type_codes->value.GetArray(); + if (id_array.Size() == 0) { + for (uint8_t code = 0; code < static_cast(children.size()); ++code) { + type_codes.push_back(code); + } + } 
else { + for (const rj::Value& val : id_array) { + DCHECK(val.IsUint()); + type_codes.push_back(static_cast(val.GetUint())); + } + } } *type = union_(children, type_codes, mode); @@ -1142,8 +1154,8 @@ class ArrayReader { RETURN_NOT_OK(GetValidityBuffer(is_valid_, &null_count, &validity_buffer)); - const auto& json_type_ids = obj_->FindMember("TYPE_ID"); - RETURN_NOT_ARRAY("TYPE_ID", json_type_ids, *obj_); + const auto& json_type_ids = obj_->FindMember("TYPE"); + RETURN_NOT_ARRAY("TYPE", json_type_ids, *obj_); RETURN_NOT_OK( GetIntArray(json_type_ids->value.GetArray(), length_, &type_id_buffer)); diff --git a/integration/data/union.json b/integration/data/union.json new file mode 100644 index 00000000000..f8e223ce372 --- /dev/null +++ b/integration/data/union.json @@ -0,0 +1,82 @@ +{ + "schema" : { + "fields" : [{ + "name" : "union", + "nullable" : true, + "type" : { + "name" : "union", + "mode" : "Sparse", + "typeIds" : [4,5] + }, + "children" : [{ + "name" : "int", + "nullable" : true, + "type" : { + "name" : "int", + "bitWidth" : 32, + "isSigned" : true + }, + "children" : [ ], + "typeLayout" : { + "vectors" : [{ + "type" : "VALIDITY", + "typeBitWidth" : 1 + },{ + "type" : "DATA", + "typeBitWidth" : 32 + }] + } + },{ + "name" : "bigint", + "nullable" : true, + "type" : { + "name" : "int", + "bitWidth" : 64, + "isSigned" : true + }, + "children" : [ ], + "typeLayout" : { + "vectors" : [{ + "type" : "VALIDITY", + "typeBitWidth" : 1 + },{ + "type" : "DATA", + "typeBitWidth" : 64 + }] + } + }], + "typeLayout" : { + "vectors" : [ + { + "type": "VALIDITY", + "typeBitWidth": 1 + }, + { + "type" : "TYPE", + "typeBitWidth" : 8 + } + ] + } + }] + }, + "batches" : [{ + "count" : 10, + "columns" : [{ + "name" : "union", + "count" : 10, + "VALIDITY" : [1,1,1,1,1,1,1,1,1,1], + "TYPE" : [4,5,4,5,4,5,4,5,4,5], + "children" : [{ + "name" : "int", + "count" : 10, + "VALIDITY" : [1,0,1,0,1,0,1,0,1,0], + "DATA" : [0,0,2,0,4,0,6,0,8,0] + },{ + "name" : "bigint", + "count" : 10, + 
"VALIDITY" : [0,1,0,1,0,1,0,1,0,1], + "DATA" : [0,1,0,3,0,5,0,7,0,9] + }] + }] + }] +} diff --git a/integration/integration_test.py b/integration/integration_test.py index 46539484488..23fe4c540a5 100644 --- a/integration/integration_test.py +++ b/integration/integration_test.py @@ -533,6 +533,55 @@ def generate_column(self, size, name=None): return StructColumn(name, size, is_valid, field_values) +class UnionType(DataType): + + def __init__(self, name, mode, type_ids, field_types, nullable=True): + DataType.__init__(self, name, nullable=nullable) + self.mode = mode + self.type_ids = type_ids + self.field_types = field_types + + def _get_type(self): + type_ids = self.type_ids if self.type_ids is not None else [] + + attrs = [ + ('name', 'union'), + ('mode', self.mode), + ('typeIds', type_ids) + ] + + return OrderedDict(attrs) + + def _get_children(self): + return [type_.get_json() for type_ in self.field_types] + + def _get_type_layout(self): + return OrderedDict([ + ('vectors', + [OrderedDict([('type', 'VALIDITY'), + ('typeBitWidth', 1)]), + OrderedDict([('type', 'TYPE'), + ('typeBitWidth', 8)])])]) + + def _make_type(self, size): + if self.type_ids is not None: + type_ids = self.type_ids + else: + type_ids = np.arange(len(self.field_types)) + + return np.random.choice(type_ids, size) + + def generate_column(self, size, name=None): + is_valid = self._make_is_valid(size) + types = self._make_type(size) + + field_values = [type_.generate_column(size) + for type_ in self.field_types] + if name is None: + name = self.name + return UnionColumn(name, size, is_valid, types, field_values) + + class Dictionary(object): def __init__(self, id_, field, values, ordered=False): @@ -603,6 +652,23 @@ def _get_children(self): return [field.get_json() for field in self.field_values] +class UnionColumn(Column): + def __init__(self, name, count, is_valid, types, field_values): + Column.__init__(self, name, count) + self.is_valid = is_valid + self.types = types + self.field_values = 
field_values + + def _get_buffers(self): + return [ + ('VALIDITY', [int(v) for v in self.is_valid]), + ('TYPE', [int(v) for v in self.types]) + ] + + def _get_children(self): + return [field.get_json() for field in self.field_values] + + class JsonRecordBatch(object): def __init__(self, count, columns): @@ -747,6 +813,22 @@ def generate_dictionary_case(): dictionaries=[dict1, dict2]) +def _generate_union_field(type_ids=None): + return UnionType('union_nullable', "Sparse", type_ids, + [get_field('f1', 'int64'), + get_field('f2', 'float64'), + get_field('f3', 'utf8'), + get_field('f4', 'binary'), + StructType('f5', [get_field('f1', 'int32'), + get_field('f2', 'utf8')]),]) + + +def generate_union_case(): + type_ids = np.random.choice(range(128), 5, replace=False).tolist() + fields = [_generate_union_field(type_ids)] + return _generate_file("union", fields, [10]) + + def get_generated_json_files(): temp_dir = tempfile.mkdtemp() @@ -758,7 +840,8 @@ def _temp_path(): generate_primitive_case([0, 0, 0]), generate_datetime_case(), generate_nested_case(), - generate_dictionary_case() + generate_dictionary_case(), + generate_union_case(), ] generated_paths = [] @@ -948,6 +1031,7 @@ def get_static_json_files(): def run_all_tests(debug=False): testers = [CPPTester(debug=debug), JavaTester(debug=debug)] + static_json_files = get_static_json_files() generated_json_files = get_generated_json_files() json_files = static_json_files + generated_json_files diff --git a/java/tools/src/main/java/org/apache/arrow/tools/Integration.java b/java/tools/src/main/java/org/apache/arrow/tools/Integration.java index d2b35e65a81..7e99ce113d7 100644 --- a/java/tools/src/main/java/org/apache/arrow/tools/Integration.java +++ b/java/tools/src/main/java/org/apache/arrow/tools/Integration.java @@ -187,7 +187,7 @@ public void execute(File arrowFile, File jsonFile) throws IOException { LOGGER.debug("ARROW schema: " + arrowSchema); LOGGER.debug("JSON Input file size: " + jsonFile.length()); 
LOGGER.debug("JSON schema: " + jsonSchema); - Validator.compareSchemas(jsonSchema, arrowSchema); + Validator.compareSchemas(arrowSchema, jsonSchema); List recordBatches = arrowReader.getRecordBlocks(); Iterator iterator = recordBatches.iterator(); diff --git a/java/vector/src/main/codegen/templates/UnionVector.java b/java/vector/src/main/codegen/templates/UnionVector.java index fe24a8674bd..85f0b0ae8a0 100644 --- a/java/vector/src/main/codegen/templates/UnionVector.java +++ b/java/vector/src/main/codegen/templates/UnionVector.java @@ -15,6 +15,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ + +import java.util.ArrayList; +import java.util.Arrays; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.complex.VectorWithOrdinal; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.util.CallBack; + <@pp.dropOutputFile /> <@pp.changeOutputFile name="/org/apache/arrow/vector/complex/UnionVector.java" /> @@ -54,14 +65,18 @@ */ public class UnionVector implements FieldVector { - private String name; - private BufferAllocator allocator; - private Accessor accessor = new Accessor(); - private Mutator mutator = new Mutator(); + private static final int MAX_TYPE_ID = 128; + private static final int MAX_MINOR_TYPE = 128; + + private final String name; + private final BufferAllocator allocator; + private final Accessor accessor; + private final Mutator mutator; int valueCount; - MapVector internalMap; - UInt1Vector typeVector; + final BitVector bits; + final UInt1Vector typeVector; + final MapVector internalMap; private NullableMapVector mapVector; private ListVector listVector; @@ -74,13 +89,39 @@ public class UnionVector implements FieldVector { private final CallBack callBack; private final List innerVectors; - public UnionVector(String name, BufferAllocator allocator, 
CallBack callBack) { + private int[] typeIds; + private int[] minorTypeToTypeId; + private int[] typeIdToMinorType; + private String[] minorTypeToFieldName; + + public UnionVector(String name, BufferAllocator allocator, CallBack callback) { + this(name, allocator, callback, null); + } + + public UnionVector(String name, BufferAllocator allocator, CallBack callBack, int[] typeIds) { this.name = name; this.allocator = allocator; - this.internalMap = new MapVector("internal", allocator, new FieldType(false, ArrowType.Struct.INSTANCE, null, null), callBack); - this.typeVector = new UInt1Vector("types", allocator); + + this.bits = new BitVector("$bits$", allocator); + this.typeVector = new UInt1Vector("$types$", allocator); + this.internalMap = new MapVector("$internal$", allocator, new FieldType(false, ArrowType.Struct.INSTANCE, null, null), callBack); + this.callBack = callBack; - this.innerVectors = Collections.unmodifiableList(Arrays.asList(typeVector)); + this.innerVectors = Collections.unmodifiableList(Arrays.asList(bits, typeVector)); + + this.typeIds = new int[MAX_TYPE_ID]; + Arrays.fill(this.typeIds, -1); + if (typeIds != null) { + System.arraycopy(typeIds, 0, this.typeIds, 0, typeIds.length); + } + this.minorTypeToTypeId = new int[MAX_MINOR_TYPE]; + Arrays.fill(this.minorTypeToTypeId, -1); + this.typeIdToMinorType = new int[MAX_TYPE_ID]; + Arrays.fill(this.typeIdToMinorType, -1); + this.minorTypeToFieldName = new String[MAX_MINOR_TYPE]; + + this.accessor = new Accessor(); + this.mutator = new Mutator(); } public BufferAllocator getAllocator() { @@ -92,9 +133,40 @@ public MinorType getMinorType() { return MinorType.UNION; } + /** + * Get type id from a child index or assign one if the child index has not been assigned a type id. + * + * The assigned type id is the first available id enumerating from 0. 
+ * + * @return the type id for the child index + */ + private int getOrAssignTypeId(int index) { + int typeId = typeIds[index]; + + if (typeId < 0) { + for (int i = 0; i < MAX_TYPE_ID; i++) { + if (typeIdToMinorType[i] < 0) { + typeId = i; + break; + } + } + typeIds[index] = typeId; + } + + return typeId; + } + @Override public void initializeChildrenFromFields(List children) { internalMap.initializeChildrenFromFields(children); + for (int i = 0; i < children.size(); ++i) { + Field child = children.get(i); + MinorType minorType = Types.getMinorTypeForArrowType(child.getType()); + + int typeId = getOrAssignTypeId(i); + setTypeMapping(minorType, typeId); + setFieldName(minorType, child.getName()); + } } @Override @@ -104,8 +176,9 @@ public List getChildrenFromFields() { @Override public void loadFieldBuffers(ArrowFieldNode fieldNode, List ownBuffers) { + // TODO: Check if need to truncate bits vector // truncate types vector buffer to size (#0) - org.apache.arrow.vector.BaseDataValueVector.truncateBufferBasedOnSize(ownBuffers, 0, typeVector.getBufferSizeFor(fieldNode.getLength())); + org.apache.arrow.vector.BaseDataValueVector.truncateBufferBasedOnSize(ownBuffers, 1, typeVector.getBufferSizeFor(fieldNode.getLength())); BaseDataValueVector.load(fieldNode, getFieldInnerVectors(), ownBuffers); this.valueCount = fieldNode.getLength(); } @@ -129,12 +202,38 @@ private FieldType fieldType(MinorType type) { } private T addOrGet(MinorType minorType, Class c) { - return internalMap.addOrGet(fieldName(minorType), fieldType(minorType), c); + String fieldName = minorTypeToFieldName[minorType.ordinal()]; + if (fieldName == null) { + fieldName = fieldName(minorType); + setFieldName(minorType, fieldName); + } + + T typedVector = internalMap.addOrGet(fieldName, fieldType(minorType), c); + int index = internalMap.getChildVectorWithOrdinal(fieldName).ordinal; + + int typeId = getOrAssignTypeId(index); + setTypeMapping(minorType, typeId); + + return typedVector; + } + + private void 
setTypeMapping(Types.MinorType minorType, int typeId) { + minorTypeToTypeId[minorType.ordinal()] = typeId; + typeIdToMinorType[typeId] = minorType.ordinal(); + } + + private void setFieldName(Types.MinorType type, String fieldName) { + if (minorTypeToFieldName[type.ordinal()] != null) { + throw new IllegalArgumentException( + String.format("Vector of minor type %s already exists.", type.toString())); + } + + minorTypeToFieldName[type.ordinal()] = fieldName; } @Override public long getValidityBufferAddress() { - return typeVector.getDataBuffer().memoryAddress(); + return bits.getBuffer().memoryAddress(); } @Override @@ -211,7 +310,8 @@ public ListVector getList() { } public int getTypeValue(int index) { - return typeVector.getAccessor().get(index); + int typeId = typeVector.getAccessor().get(index); + return typeIdToMinorType[typeId]; } public UInt1Vector getTypeVector() { @@ -221,20 +321,22 @@ public UInt1Vector getTypeVector() { @Override public void allocateNew() throws OutOfMemoryException { internalMap.allocateNew(); + bits.allocateNew(); + bits.zeroVector(); typeVector.allocateNew(); - if (typeVector != null) { - typeVector.zeroVector(); - } + typeVector.zeroVector(); } @Override public boolean allocateNewSafe() { boolean safe = internalMap.allocateNewSafe(); + safe = safe && bits.allocateNewSafe(); + if (safe) { + bits.zeroVector(); + } safe = safe && typeVector.allocateNewSafe(); if (safe) { - if (typeVector != null) { - typeVector.zeroVector(); - } + typeVector.zeroVector(); } return safe; } @@ -242,6 +344,7 @@ public boolean allocateNewSafe() { @Override public void reAlloc() { internalMap.reAlloc(); + bits.reAlloc(); typeVector.reAlloc(); } @@ -261,6 +364,7 @@ public void close() { @Override public void clear() { + bits.clear(); typeVector.clear(); internalMap.clear(); } @@ -269,12 +373,14 @@ public void clear() { public Field getField() { List childFields = new ArrayList<>(); List children = internalMap.getChildren(); - int[] typeIds = new 
int[children.size()]; for (ValueVector v : children) { - typeIds[childFields.size()] = v.getMinorType().ordinal(); childFields.add(v.getField()); } - return new Field(name, FieldType.nullable(new ArrowType.Union(Sparse, typeIds)), childFields); + + int[] typeIds = Arrays.copyOfRange(this.typeIds, 0, children.size()); + + return new Field(name, + FieldType.nullable(new ArrowType.Union(Sparse, typeIds)), childFields); } @Override @@ -322,24 +428,30 @@ public FieldVector addVector(FieldVector v) { private class TransferImpl implements TransferPair { private final TransferPair internalMapVectorTransferPair; private final TransferPair typeVectorTransferPair; + private final TransferPair bitsTransferPair; private final UnionVector to; public TransferImpl(String name, BufferAllocator allocator, CallBack callBack) { - to = new UnionVector(name, allocator, callBack); - internalMapVectorTransferPair = internalMap.makeTransferPair(to.internalMap); - typeVectorTransferPair = typeVector.makeTransferPair(to.typeVector); + this(new UnionVector(name, allocator, callBack)); } public TransferImpl(UnionVector to) { this.to = to; internalMapVectorTransferPair = internalMap.makeTransferPair(to.internalMap); typeVectorTransferPair = typeVector.makeTransferPair(to.typeVector); + bitsTransferPair = bits.makeTransferPair(to.bits); + + this.to.typeIds = typeIds; + this.to.minorTypeToTypeId = minorTypeToTypeId; + this.to.typeIdToMinorType = typeIdToMinorType; + this.to.minorTypeToFieldName = minorTypeToFieldName; } @Override public void transfer() { internalMapVectorTransferPair.transfer(); typeVectorTransferPair.transfer(); + bitsTransferPair.transfer(); to.valueCount = valueCount; } @@ -422,11 +534,13 @@ public Iterator iterator() { } public class Accessor extends BaseValueVector.BaseAccessor { + final BitVector.Accessor bAccessor = bits.getAccessor(); @Override public Object getObject(int index) { - int type = typeVector.getAccessor().get(index); - switch (MinorType.values()[type]) { + 
int minorType = typeIdToMinorType[typeVector.getAccessor().get(index)]; + assert minorType >= 0; + switch (Types.MinorType.values()[minorType]) { case NULL: return null; <#list vv.types as type> @@ -445,7 +559,7 @@ public Object getObject(int index) { case LIST: return getList().getAccessor().getObject(index); default: - throw new UnsupportedOperationException("Cannot support type: " + MinorType.values()[type]); + throw new UnsupportedOperationException("Cannot support type: " + MinorType.values()[minorType]); } } @@ -469,11 +583,11 @@ public int getValueCount() { @Override public boolean isNull(int index) { - return typeVector.getAccessor().get(index) == 0; + return isSet(index) == 0; } public int isSet(int index) { - return isNull(index) ? 0 : 1; + return bAccessor.get(index); } } @@ -484,6 +598,7 @@ public class Mutator extends BaseValueVector.BaseMutator { @Override public void setValueCount(int valueCount) { UnionVector.this.valueCount = valueCount; + bits.getMutator().setValueCount(valueCount); typeVector.getMutator().setValueCount(valueCount); internalMap.getMutator().setValueCount(valueCount); } @@ -500,7 +615,7 @@ public void setSafe(int index, UnionHolder holder) { <#list type.minor as minor> <#assign name = minor.class?cap_first /> <#assign fields = minor.fields!type.fields /> - <#assign uncappedName = name?uncap_first/> + <#assign uncappedName = name?uncap_first /> <#if !minor.typeParams?? 
> case ${name?upper_case}: Nullable${name}Holder ${uncappedName}Holder = new Nullable${name}Holder(); @@ -537,12 +652,46 @@ public void setSafe(int index, Nullable${name}Holder holder) { + public void setNull(int index) {bits.getMutator().setSafe(index, 0);} + public void setType(int index, MinorType type) { - typeVector.getMutator().setSafe(index, (byte) type.ordinal()); + int typeId = minorTypeToTypeId[type.ordinal()]; + + // Initialize vector and typeId if it doesn't exist + if (typeId < 0) { + switch (type) { + <#list vv.types as type> + <#list type.minor as minor> + <#assign name = minor.class ? cap_first /> + <#assign fields = minor.fields!type.fields /> + <#assign uncappedName = name?uncap_first /> + <#if !minor.typeParams?? > + case ${name?upper_case}: + get${name}Vector(); + break; + + + + case MAP: + getMap(); + break; + case LIST: + getList(); + break; + default: + throw new UnsupportedOperationException("Cannot support type: " + type); + } + + typeId = minorTypeToTypeId[type.ordinal()]; + } + + assert typeId >= 0; + bits.getMutator().setSafeToOne(index); + typeVector.getMutator().setSafe(index, (byte) typeId); } @Override - public void reset() { } + public void reset() { } @Override public void generateTestData(int values) { } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java index db0ff86df47..064758328df 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/AbstractContainerVector.java @@ -22,9 +22,10 @@ import org.apache.arrow.memory.OutOfMemoryException; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; -import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.UnionMode; import org.apache.arrow.vector.types.pojo.ArrowType.List; import 
org.apache.arrow.vector.types.pojo.ArrowType.Struct; +import org.apache.arrow.vector.types.pojo.ArrowType.Union; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.CallBack; @@ -109,6 +110,6 @@ public ListVector addOrGetList(String name) { } public UnionVector addOrGetUnion(String name) { - return addOrGet(name, FieldType.nullable(MinorType.UNION.getType()), UnionVector.class); + return addOrGet(name, FieldType.nullable(new Union(UnionMode.Sparse, null)), UnionVector.class); } } \ No newline at end of file diff --git a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java index 484a82fdaab..86605e1b401 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/file/json/JsonFileReader.java @@ -66,7 +66,6 @@ import org.apache.arrow.vector.VarBinaryVector; import org.apache.arrow.vector.VarCharVector; import org.apache.arrow.vector.VectorSchemaRoot; -import org.apache.arrow.vector.complex.NullableMapVector; import org.apache.arrow.vector.dictionary.Dictionary; import org.apache.arrow.vector.dictionary.DictionaryProvider; import org.apache.arrow.vector.schema.ArrowVectorType; @@ -215,6 +214,11 @@ public VectorSchemaRoot read() throws IOException { } } + /* + * TODO: This method doesn't load some vectors correctly. For instance, it doesn't set `lastSet` + * in ListVector, VarCharVector, NullableVarBinaryVector A better way of implementing this + * function is to use `loadFieldBuffers` methods in FieldVector. 
+ */ private void readVector(Field field, FieldVector vector) throws JsonParseException, IOException { List vectorTypes = field.getTypeLayout().getVectorTypes(); List fieldInnerVectors = vector.getFieldInnerVectors(); @@ -229,6 +233,8 @@ private void readVector(Field field, FieldVector vector) throws JsonParseExcepti throw new IllegalArgumentException("Expected field " + field.getName() + " but got " + name); } int count = readNextField("count", Integer.class); + vector.allocateNew(); + vector.getMutator().setValueCount(count); for (int v = 0; v < vectorTypes.size(); v++) { ArrowVectorType vectorType = vectorTypes.get(v); BufferBacked innerVector = fieldInnerVectors.get(v); @@ -262,9 +268,6 @@ private void readVector(Field field, FieldVector vector) throws JsonParseExcepti } readToken(END_ARRAY); } - if (vector instanceof NullableMapVector) { - ((NullableMapVector) vector).valueCount = count; - } } readToken(END_OBJECT); } diff --git a/java/vector/src/main/java/org/apache/arrow/vector/schema/TypeLayout.java b/java/vector/src/main/java/org/apache/arrow/vector/schema/TypeLayout.java index 29407bf1ab4..b37fd03fd60 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/schema/TypeLayout.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/schema/TypeLayout.java @@ -81,7 +81,8 @@ public TypeLayout visit(Union type) { break; case Sparse: vectors = asList( - typeVector() // type of the value at the index or 0 if null + validityVector(), + typeVector() ); break; default: diff --git a/java/vector/src/main/java/org/apache/arrow/vector/schema/VectorLayout.java b/java/vector/src/main/java/org/apache/arrow/vector/schema/VectorLayout.java index 0871baf38ed..6a1aa07d5a1 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/schema/VectorLayout.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/schema/VectorLayout.java @@ -32,7 +32,7 @@ public class VectorLayout implements FBSerializable { private static final VectorLayout VALIDITY_VECTOR = new 
VectorLayout(VALIDITY, 1); private static final VectorLayout OFFSET_VECTOR = new VectorLayout(OFFSET, 32); - private static final VectorLayout TYPE_VECTOR = new VectorLayout(TYPE, 32); + private static final VectorLayout TYPE_VECTOR = new VectorLayout(TYPE, 8); private static final VectorLayout BOOLEAN_VECTOR = new VectorLayout(DATA, 1); private static final VectorLayout VALUES_64 = new VectorLayout(DATA, 64); private static final VectorLayout VALUES_32 = new VectorLayout(DATA, 32); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java b/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java index c57dd6dafe9..345f376b805 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/types/Types.java @@ -453,13 +453,13 @@ public FieldWriter getNewFieldWriter(ValueVector vector) { throw new UnsupportedOperationException("FieldWriter not implemented for FixedSizeList type"); } }, - UNION(new Union(Sparse, null)) { + UNION(null) { @Override public FieldVector getNewVector(String name, FieldType fieldType, BufferAllocator allocator, CallBack schemaChangeCallback) { if (fieldType.getDictionary() != null) { throw new UnsupportedOperationException("Dictionary encoding not supported for complex types"); } - return new UnionVector(name, allocator, schemaChangeCallback); + return new UnionVector(name, allocator, schemaChangeCallback, ((Union) fieldType.getType()).getTypeIds()); } @Override diff --git a/java/vector/src/main/java/org/apache/arrow/vector/util/Validator.java b/java/vector/src/main/java/org/apache/arrow/vector/util/Validator.java index 5851bd5fa5d..0091f4a91ad 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/util/Validator.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/util/Validator.java @@ -46,7 +46,7 @@ public class Validator { */ public static void compareSchemas(Schema schema1, Schema schema2) { if 
(!schema2.equals(schema1)) { - throw new IllegalArgumentException("Different schemas:\n" + schema2 + "\n" + schema1); + throw new IllegalArgumentException("Different schemas:\n" + schema1 + "\n" + schema2); } } @@ -122,8 +122,18 @@ public static void compareFieldVectors(FieldVector vector1, FieldVector vector2) Object obj1 = vector1.getAccessor().getObject(j); Object obj2 = vector2.getAccessor().getObject(j); if (!equals(field1.getType(), obj1, obj2)) { + String obj1Str; + String obj2Str; + if (obj1 instanceof byte[] && obj2 instanceof byte[]) { + obj1Str = Arrays.toString((byte[]) obj1); + obj2Str = Arrays.toString((byte[]) obj2); + } else { + obj1Str = String.valueOf(obj1); + obj2Str = String.valueOf(obj2); /* String.valueOf is null-safe: getObject() may return null, and the mismatch must surface as the IllegalArgumentException below, not an NPE */ + } + throw new IllegalArgumentException( - "Different values in column:\n" + field1 + " at index " + j + ": " + obj1 + " != " + obj2); + "Different values in column:\n" + field1 + " at index " + j + ": " + obj1Str + " != " + obj2Str); } } } @@ -140,7 +150,7 @@ static boolean equals(ArrowType type, final Object o1, final Object o2) { default: throw new UnsupportedOperationException("unsupported precision: " + fpType); } - } else if (type instanceof ArrowType.Binary) { + } else if (o1 instanceof byte[] && o2 instanceof byte[]) { return Arrays.equals((byte[]) o1, (byte[]) o2); } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestUnionVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestUnionVector.java index 86f0bf337f9..4f5772d0360 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestUnionVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestUnionVector.java @@ -18,9 +18,11 @@ package org.apache.arrow.vector; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import java.util.Arrays; import java.util.List; import org.apache.arrow.memory.BufferAllocator; @@ -31,6 +33,7 @@ import 
org.apache.arrow.vector.holders.NullableFloat4Holder; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.util.TransferPair; import org.junit.After; import org.junit.Before; @@ -72,15 +75,21 @@ public void testUnionVector() throws Exception { mutator.setValueCount(4); // check that what we wrote is correct + int[] typeIds = ((ArrowType.Union) unionVector.getField().getFieldType().getType()) + .getTypeIds(); + assertArrayEquals(new int[] {0}, typeIds); + final UnionVector.Accessor accessor = unionVector.getAccessor(); assertEquals(4, accessor.getValueCount()); assertEquals(false, accessor.isNull(0)); + assertEquals(MinorType.UINT4.ordinal(), unionVector.getTypeValue(0)); assertEquals(100, accessor.getObject(0)); assertEquals(true, accessor.isNull(1)); assertEquals(false, accessor.isNull(2)); + assertEquals(MinorType.UINT4.ordinal(), unionVector.getTypeValue(2)); assertEquals(100, accessor.getObject(2)); assertEquals(true, accessor.isNull(3)); @@ -357,7 +366,7 @@ public void testGetBufferAddress() throws Exception { assertTrue(error); } - assertEquals(1, buffers.size()); + assertEquals(2, buffers.size()); assertEquals(bitAddress, buffers.get(0).memoryAddress()); } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/file/json/TestJSONFile.java b/java/vector/src/test/java/org/apache/arrow/vector/file/json/TestJSONFile.java index 24b2138386d..d1436502127 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/file/json/TestJSONFile.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/file/json/TestJSONFile.java @@ -30,6 +30,7 @@ import org.apache.arrow.vector.dictionary.DictionaryProvider.MapDictionaryProvider; import org.apache.arrow.vector.file.BaseFileTest; import org.apache.arrow.vector.types.pojo.Schema; +import org.apache.arrow.vector.util.Validator; import org.junit.Assert; import org.junit.Test; import 
org.slf4j.Logger; @@ -95,29 +96,31 @@ public void testWriteReadUnionJSON() throws IOException { int count = COUNT; try ( BufferAllocator vectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE); - NullableMapVector parent = NullableMapVector.empty("parent", vectorAllocator)) { - + NullableMapVector parent = NullableMapVector.empty("parent", vectorAllocator) + ) { writeUnionData(count, parent); - printVectors(parent.getChildrenFromFields()); - VectorSchemaRoot root = new VectorSchemaRoot(parent.getChild("root")); - validateUnionData(count, root); - - writeJSON(file, root, null); - } - // read - try ( - BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE); - BufferAllocator vectorAllocator = allocator.newChildAllocator("final vectors", 0, Integer.MAX_VALUE); - ) { - JsonFileReader reader = new JsonFileReader(file, readerAllocator); - Schema schema = reader.start(); - LOGGER.debug("reading schema: " + schema); - - // initialize vectors - try (VectorSchemaRoot root = reader.read();) { + try ( + VectorSchemaRoot root = new VectorSchemaRoot(parent.getChild("root")) + ) { validateUnionData(count, root); + writeJSON(file, root, null); + + // read + try ( + BufferAllocator readerAllocator = allocator.newChildAllocator("reader", 0, Integer.MAX_VALUE) + ) { + JsonFileReader reader = new JsonFileReader(file, readerAllocator); + + Schema schema = reader.start(); + LOGGER.debug("reading schema: " + schema); + + try (VectorSchemaRoot rootFromJson = reader.read();) { + validateUnionData(count, rootFromJson); + Validator.compareVectorSchemaRoot(root, rootFromJson); + } + } } } }