From 3185af9fab999a33d4f7095a140134087b424a62 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Mon, 24 Jun 2019 18:12:10 -0700 Subject: [PATCH 1/3] Fix MapType to create a non-nullable struct field --- cpp/src/arrow/ipc/metadata-internal.cc | 4 +--- cpp/src/arrow/type.cc | 7 +++++-- integration/integration_test.py | 20 ++++---------------- 3 files changed, 10 insertions(+), 21 deletions(-) diff --git a/cpp/src/arrow/ipc/metadata-internal.cc b/cpp/src/arrow/ipc/metadata-internal.cc index 4b349cb4a69..4e1a1576ddb 100644 --- a/cpp/src/arrow/ipc/metadata-internal.cc +++ b/cpp/src/arrow/ipc/metadata-internal.cc @@ -320,9 +320,7 @@ Status ConcreteTypeFromFlatbuffer(flatbuf::Type type, const void* type_data, if (children.size() != 1) { return Status::Invalid("Map must have exactly 1 child field"); } - if ( // FIXME(bkietz) temporarily disabled: this field is sometimes read nullable - // children[0]->nullable() || - children[0]->type()->id() != Type::STRUCT || + if (children[0]->nullable() || children[0]->type()->id() != Type::STRUCT || children[0]->type()->num_children() != 2) { return Status::Invalid("Map's key-item pairs must be non-nullable structs"); } diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index b21533121c5..44823657d63 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -153,8 +153,11 @@ std::string ListType::ToString() const { MapType::MapType(const std::shared_ptr& key_type, const std::shared_ptr& item_type, bool keys_sorted) - : ListType(struct_({std::make_shared("key", key_type, false), - std::make_shared("item", item_type)})), + : ListType(std::make_shared( + "$data$", + struct_({std::make_shared("key", key_type, false), + std::make_shared("item", item_type)}), + false)), keys_sorted_(keys_sorted) { id_ = type_id; } diff --git a/integration/integration_test.py b/integration/integration_test.py index aca05747c72..5e4141a84ad 100644 --- a/integration/integration_test.py +++ b/integration/integration_test.py @@ -714,7 +714,7 @@ def __init__(self, name, key_type, item_type, nullable=True, assert not key_type.nullable self.key_type = key_type self.item_type = item_type - self.pair_type = StructType('item', [key_type, item_type], False) + self.pair_type = StructType('$data$', [key_type, item_type], False) self.keysSorted = keysSorted def _get_type(self): @@ -1058,18 +1058,6 @@ def generate_interval_case(): return _generate_file("interval", fields, batch_sizes) -def generate_map_case(): - # TODO(bkietz): separated from nested_case so it can be - # independently skipped, consolidate after Java supports map - fields = [ - MapType('map_nullable', get_field('key', 'utf8', False), - get_field('item', 'int32')), - ] - - batch_sizes = [7, 10] - return _generate_file("map", fields, batch_sizes) - - def generate_nested_case(): fields = [ ListType('list_nullable', get_field('item', 'int32')), @@ -1077,6 +1065,8 @@ def generate_nested_case(): get_field('item', 'int32'), 4), StructType('struct_nullable', [get_field('f1', 'int32'), get_field('f2', 'utf8')]), + MapType('map_nullable', get_field('key', 'utf8', False), + get_field('item', 'int32')), # TODO(wesm): this causes segfault # ListType('list_nonnullable', get_field('item', 'int32'), False), @@ -1150,7 +1140,6 @@ def _temp_path(): generate_decimal_case(), generate_datetime_case(), generate_interval_case(), - generate_map_case(), generate_nested_case(), generate_dictionary_case(), generate_nested_dictionary_case().skip_category(SKIP_ARROW) @@ -1221,7 +1210,7 @@ def _compare_implementations(self, producer, consumer): file_id = guid()[:8] if (('JS' in (producer.name, consumer.name) or - 'Java' in (producer.name, consumer.name)) and + 'Java' in (producer.name, consumer.name)) and "map" in test_case.name): print('TODO(ARROW-1279): Enable map tests ' + ' for Java and JS once Java supports them and JS\'' + @@ -1707,7 +1696,6 @@ def run_all_tests(args): def write_js_test_json(directory): - generate_map_case().write(os.path.join(directory, 'map.json')) generate_nested_case().write(os.path.join(directory, 'nested.json')) generate_decimal_case().write(os.path.join(directory, 'decimal.json')) generate_datetime_case().write(os.path.join(directory, 'datetime.json')) From 30777804fd39facfaa9feb162582f161fec35cca Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 25 Jun 2019 00:13:38 -0700 Subject: [PATCH 2/3] rename MapType struct field to 'entries' --- cpp/src/arrow/type.cc | 4 +-- integration/integration_test.py | 24 +++++++++----- .../arrow/vector/complex/MapVector.java | 31 +++++++++++++++++-- 3 files changed, 48 insertions(+), 11 deletions(-) diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 44823657d63..b3fcb0cde84 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -154,9 +154,9 @@ std::string ListType::ToString() const { MapType::MapType(const std::shared_ptr& key_type, const std::shared_ptr& item_type, bool keys_sorted) : ListType(std::make_shared( - "$data$", + "entries", struct_({std::make_shared("key", key_type, false), - std::make_shared("item", item_type)}), + std::make_shared("value", item_type)}), false)), keys_sorted_(keys_sorted) { id_ = type_id; diff --git a/integration/integration_test.py b/integration/integration_test.py index 5e4141a84ad..dbc03c73a1a 100644 --- a/integration/integration_test.py +++ b/integration/integration_test.py @@ -714,7 +714,7 @@ def __init__(self, name, key_type, item_type, nullable=True, assert not key_type.nullable self.key_type = key_type self.item_type = item_type - self.pair_type = StructType('$data$', [key_type, item_type], False) + self.pair_type = StructType('entries', [key_type, item_type], False) self.keysSorted = keysSorted def _get_type(self): @@ -1058,6 +1058,18 @@ def generate_interval_case(): return _generate_file("interval", fields, batch_sizes) +def generate_map_case(): + # TODO(bkietz): separated from nested_case so it can be + # independently skipped, consolidate after JS supports map + fields = [ + MapType('map_nullable', get_field('key', 'utf8', False), + get_field('value', 'int32')), + ] + + batch_sizes = [7, 10] + return _generate_file("map", fields, batch_sizes) + + def generate_nested_case(): fields = [ ListType('list_nullable', get_field('item', 'int32')), @@ -1065,8 +1077,6 @@ def generate_nested_case(): get_field('item', 'int32'), 4), StructType('struct_nullable', [get_field('f1', 'int32'), get_field('f2', 'utf8')]), - MapType('map_nullable', get_field('key', 'utf8', False), - get_field('item', 'int32')), # TODO(wesm): this causes segfault # ListType('list_nonnullable', get_field('item', 'int32'), False), @@ -1140,6 +1150,7 @@ def _temp_path(): generate_decimal_case(), generate_datetime_case(), generate_interval_case(), + generate_map_case(), generate_nested_case(), generate_dictionary_case(), generate_nested_dictionary_case().skip_category(SKIP_ARROW) @@ -1209,12 +1220,10 @@ def _compare_implementations(self, producer, consumer): file_id = guid()[:8] - if (('JS' in (producer.name, consumer.name) or - 'Java' in (producer.name, consumer.name)) and + if ('JS' in (producer.name, consumer.name) and "map" in test_case.name): print('TODO(ARROW-1279): Enable map tests ' + - ' for Java and JS once Java supports them and JS\'' + - ' are unbroken') + ' for JS once they are unbroken') continue if ('JS' in (producer.name, consumer.name) and @@ -1696,6 +1705,7 @@ def run_all_tests(args): def write_js_test_json(directory): + generate_map_case().write(os.path.join(directory, 'map.json')) generate_nested_case().write(os.path.join(directory, 'nested.json')) generate_decimal_case().write(os.path.join(directory, 'decimal.json')) generate_datetime_case().write(os.path.join(directory, 'datetime.json')) diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/MapVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/MapVector.java index 340fb2ae1ad..a3671501ba7 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/MapVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/MapVector.java @@ -25,14 +25,17 @@ import org.apache.arrow.vector.AddOrGetResult; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.ZeroVector; import org.apache.arrow.vector.complex.impl.UnionMapReader; import org.apache.arrow.vector.complex.impl.UnionMapWriter; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.Types.MinorType; +import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID; import org.apache.arrow.vector.types.pojo.ArrowType.Map; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.util.CallBack; +import org.apache.arrow.vector.util.SchemaChangeRuntimeException; /** * A MapVector is used to store entries of key/value pairs. It is a container vector that is @@ -46,6 +49,9 @@ public class MapVector extends ListVector { public static final String KEY_NAME = "key"; public static final String VALUE_NAME = "value"; + // TODO: this is only used for addOrGetVector because ListVector declares it private + protected CallBack callBack; + /** * Construct an empty MapVector with no data. Child vectors must be added subsequently. * @@ -68,6 +74,7 @@ public static MapVector empty(String name, BufferAllocator allocator, boolean ke */ public MapVector(String name, BufferAllocator allocator, FieldType fieldType, CallBack callBack) { super(name, allocator, fieldType, callBack); + this.callBack = callBack; reader = new UnionMapReader(this); } @@ -121,9 +128,29 @@ public UnionMapReader getReader() { */ @Override public AddOrGetResult addOrGetVector(FieldType fieldType) { - AddOrGetResult result = super.addOrGetVector(fieldType); + + // TODO: can call super method once DATA_VECTOR_NAME is configurable + boolean created = false; + if (vector instanceof ZeroVector) { + vector = fieldType.createNewSingleVector("entries", allocator, callBack); + // returned vector must have the same field + created = true; + if (callBack != null && + // not a schema change if changing from ZeroVector to ZeroVector + (fieldType.getType().getTypeID() != ArrowTypeID.Null)) { + callBack.doWork(); + } + } + + if (vector.getField().getType().getTypeID() != fieldType.getType().getTypeID()) { + final String msg = String.format("Inner vector type mismatch. Requested type: [%s], actual type: [%s]", + fieldType.getType().getTypeID(), vector.getField().getType().getTypeID()); + throw new SchemaChangeRuntimeException(msg); + } + reader = new UnionMapReader(this); - return result; + + return new AddOrGetResult<>((T) vector, created); } /** From 6259a4574dcd6783e60cd7dadef2053060ba18ad Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Tue, 25 Jun 2019 15:34:31 -0700 Subject: [PATCH 3/3] rename MapType fields for internal test --- cpp/src/arrow/ipc/json-internal.cc | 4 ++-- cpp/src/arrow/ipc/json-test.cc | 2 +- cpp/src/arrow/ipc/test-common.cc | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/ipc/json-internal.cc b/cpp/src/arrow/ipc/json-internal.cc index 42663c0178d..135296551c9 100644 --- a/cpp/src/arrow/ipc/json-internal.cc +++ b/cpp/src/arrow/ipc/json-internal.cc @@ -1274,8 +1274,8 @@ class ArrayReader { Status Visit(const MapType& type) { auto list_type = std::make_shared(field( - "item", - struct_({field("key", type.key_type(), false), field("item", type.item_type())}), + "entries", + struct_({field("key", type.key_type(), false), field("value", type.item_type())}), false)); std::shared_ptr list_array; RETURN_NOT_OK(CreateList(list_type, &list_array)); diff --git a/cpp/src/arrow/ipc/json-test.cc b/cpp/src/arrow/ipc/json-test.cc index fb57fa7f52e..338552dd575 100644 --- a/cpp/src/arrow/ipc/json-test.cc +++ b/cpp/src/arrow/ipc/json-test.cc @@ -204,7 +204,7 @@ TEST(TestJsonArrayWriter, NestedTypes) { TestArrayRoundTrip(list_array); - // List + // Map auto map_type = map(utf8(), int32()); auto keys_array = ArrayFromJSON(utf8(), R"(["a", "b", "c", "d", "a", "b", "c"])"); diff --git a/cpp/src/arrow/ipc/test-common.cc b/cpp/src/arrow/ipc/test-common.cc index 12adebc2516..47c307659f0 100644 --- a/cpp/src/arrow/ipc/test-common.cc +++ b/cpp/src/arrow/ipc/test-common.cc @@ -120,7 +120,7 @@ Status MakeRandomMapArray(const std::shared_ptr& key_array, bool include_nulls, MemoryPool* pool, std::shared_ptr* out) { auto pair_type = struct_( - {field("key", key_array->type(), false), field("item", item_array->type())}); + {field("key", key_array->type(), false), field("value", item_array->type())}); auto pair_array = std::make_shared(pair_type, num_maps, ArrayVector{key_array, item_array});