diff --git a/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml b/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml
index ecae656aeb2f..9de15ac8e94d 100644
--- a/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml
+++ b/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml
@@ -97,6 +97,7 @@ examples:
   "\u000A": 10
   "\u00c8\u0001": 200
   "\u00e8\u0007": 1000
+  "\u00a9\u0046": 9001
   "\u00ff\u00ff\u00ff\u00ff\u00ff\u00ff\u00ff\u00ff\u00ff\u0001": -1

---

@@ -275,3 +276,25 @@ examples:
   "\u007f\u00f0\0\0\0\0\0\0": "Infinity"
   "\u00ff\u00f0\0\0\0\0\0\0": "-Infinity"
   "\u007f\u00f8\0\0\0\0\0\0": "NaN"
+
+---
+
+coder:
+  urn: "beam:coder:row:v1"
+  # str: string, i32: int32, f64: float64, arr: array[string]
+  payload: "\n\t\n\x03str\x1a\x02\x10\x07\n\t\n\x03i32\x1a\x02\x10\x03\n\t\n\x03f64\x1a\x02\x10\x06\n\r\n\x03arr\x1a\x06\x1a\x04\n\x02\x10\x07\x12$4e5e554c-d4c1-4a5d-b5e1-f3293a6b9f05"
+nested: false
+examples:
+  "\u0004\u0000\u0003foo\u00a9\u0046\u003f\u00b9\u0099\u0099\u0099\u0099\u0099\u009a\0\0\0\u0003\u0003foo\u0003bar\u0003baz": {str: "foo", i32: 9001, f64: "0.1", arr: ["foo", "bar", "baz"]}
+
+---
+
+coder:
+  urn: "beam:coder:row:v1"
+  # str: nullable string, i32: nullable int32, f64: nullable float64
+  payload: "\n\x0b\n\x03str\x1a\x04\x08\x01\x10\x07\n\x0b\n\x03i32\x1a\x04\x08\x01\x10\x03\n\x0b\n\x03f64\x1a\x04\x08\x01\x10\x06\x12$b20c6545-57af-4bc8-b2a9-51ace21c7393"
+nested: false
+examples:
+  "\u0003\u0001\u0007": {str: null, i32: null, f64: null}
+  "\u0003\u0001\u0004\u0003foo\u00a9\u0046": {str: "foo", i32: 9001, f64: null}
+  "\u0003\u0000\u0003foo\u00a9\u0046\u003f\u00b9\u0099\u0099\u0099\u0099\u0099\u009a": {str: "foo", i32: 9001, f64: "0.1"}
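A quick way to sanity-check these examples is to decode one by hand. The short Python 3 sketch below (illustrative only; the read_varint helper is written here for the example and is not an SDK function) walks the nullable-row example "\u0003\u0001\u0004\u0003foo\u00a9\u0046" following the beam:coder:row:v1 layout documented in beam_runner_api.proto below: a varint field count, a length-prefixed null bitmap, then the encodings of the non-null fields.

    import io

    def read_varint(buf):
        # Unsigned LEB128 varint, as used by beam:coder:varint:v1 for
        # non-negative values.
        result, shift = 0, 0
        while True:
            byte = buf.read(1)[0]
            result |= (byte & 0x7F) << shift
            shift += 7
            if not byte & 0x80:
                return result

    buf = io.BytesIO(b"\x03\x01\x04\x03foo\xa9\x46")
    nfields = read_varint(buf)           # 3: number of attributes in the schema
    bitmap = buf.read(read_varint(buf))  # b"\x04": bit 2 set, so field 2 (f64) is null
    nulls = [bool(bitmap[i // 8] >> (i % 8) & 1) for i in range(nfields)]
    str_val = buf.read(read_varint(buf)).decode("utf-8")  # length-prefixed UTF-8: "foo"
    i32_val = read_varint(buf)           # bytes 0xa9 0x46 decode to 9001
    assert (nulls, str_val, i32_val) == ([False, False, True], "foo", 9001)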
diff --git a/model/pipeline/src/main/proto/beam_runner_api.proto b/model/pipeline/src/main/proto/beam_runner_api.proto
index ec05ef045b6a..90f52fc6c776 100644
--- a/model/pipeline/src/main/proto/beam_runner_api.proto
+++ b/model/pipeline/src/main/proto/beam_runner_api.proto
@@ -645,6 +645,50 @@ message StandardCoders {
     // Components: Coder for a single element.
     // Experimental.
     STATE_BACKED_ITERABLE = 9 [(beam_urn) = "beam:coder:state_backed_iterable:v1"];
+
+    // Additional Standard Coders
+    // --------------------------
+    // The following coders are not required to be implemented for an SDK or
+    // runner to support the Beam model, but enable users to take advantage of
+    // schema-aware transforms.
+
+    // Encodes a "row", an element with a known schema, defined by an
+    // instance of Schema from schema.proto.
+    //
+    // A row is encoded as the concatenation of:
+    //   - The number of attributes in the schema, encoded with
+    //     beam:coder:varint:v1. This makes it possible to detect certain
+    //     allowed schema changes (appending or removing columns) in
+    //     long-running streaming pipelines.
+    //   - A byte array representing a packed bitset indicating null fields (a
+    //     1 indicating a null) encoded with beam:coder:bytes:v1. The unused
+    //     bits in the last byte must be set to 0. If there are no nulls an
+    //     empty byte array is encoded.
+    //     The two-byte bitset (not including the length-prefix) for the row
+    //     [NULL, 0, 0, 0, NULL, 0, 0, NULL, 0, NULL] would be
+    //     [0b10010001, 0b00000010]
+    //   - An encoding for each non-null field, concatenated together.
+    //
+    // Schema types are mapped to coders as follows:
+    //   AtomicType:
+    //     BYTE:      not yet a standard coder (BEAM-7996)
+    //     INT16:     not yet a standard coder (BEAM-7996)
+    //     INT32:     beam:coder:varint:v1
+    //     INT64:     beam:coder:varint:v1
+    //     FLOAT:     not yet a standard coder (BEAM-7996)
+    //     DOUBLE:    beam:coder:double:v1
+    //     STRING:    beam:coder:string_utf8:v1
+    //     BOOLEAN:   beam:coder:bool:v1
+    //     BYTES:     beam:coder:bytes:v1
+    //   ArrayType:   beam:coder:iterable:v1 (always has a known length)
+    //   MapType:     not yet a standard coder (BEAM-7996)
+    //   RowType:     beam:coder:row:v1
+    //   LogicalType: Uses the coder for its representation.
+    //
+    // The payload for RowCoder is an instance of Schema.
+    // Components: None
+    // Experimental.
+    ROW = 13 [(beam_urn) = "beam:coder:row:v1"];
   }
 }
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java
index 9c4e232b6792..f2cc8fa5433e 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java
@@ -17,16 +17,22 @@
  */
 package org.apache.beam.runners.core.construction;
 
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
+
 import java.util.Collections;
 import java.util.List;
+import org.apache.beam.model.pipeline.v1.SchemaApi;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.LengthPrefixCoder;
+import org.apache.beam.sdk.coders.RowCoder;
+import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.util.InstanceBuilder;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
+import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.InvalidProtocolBufferException;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 
 /** {@link CoderTranslator} implementations for known coder types. */
@@ -118,6 +124,33 @@ public FullWindowedValueCoder<?> fromComponents(List<Coder<?>> components) {
     };
   }
 
+  static CoderTranslator<RowCoder> row() {
+    return new CoderTranslator<RowCoder>() {
+      @Override
+      public List<? extends Coder<?>> getComponents(RowCoder from) {
+        return ImmutableList.of();
+      }
+
+      @Override
+      public byte[] getPayload(RowCoder from) {
+        return SchemaTranslation.schemaToProto(from.getSchema()).toByteArray();
+      }
+
+      @Override
+      public RowCoder fromComponents(List<Coder<?>> components, byte[] payload) {
+        checkArgument(
+            components.isEmpty(), "Expected empty component list, but received: " + components);
+        Schema schema;
+        try {
+          schema = SchemaTranslation.fromProto(SchemaApi.Schema.parseFrom(payload));
+        } catch (InvalidProtocolBufferException e) {
+          throw new RuntimeException("Unable to parse schema for RowCoder: ", e);
+        }
+        return RowCoder.of(schema);
+      }
+    };
+  }
+
   public abstract static class SimpleStructuredCoderTranslator<T extends Coder<?>>
       implements CoderTranslator<T> {
     @Override
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java
index 8294fe001c4a..854f5235b5f4 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java
@@ -29,6 +29,7 @@
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.LengthPrefixCoder;
+import org.apache.beam.sdk.coders.RowCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarLongCoder;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
@@ -60,6 +61,7 @@ public class ModelCoderRegistrar implements CoderTranslatorRegistrar {
           .put(GlobalWindow.Coder.class, ModelCoders.GLOBAL_WINDOW_CODER_URN)
           .put(FullWindowedValueCoder.class, ModelCoders.WINDOWED_VALUE_CODER_URN)
           .put(DoubleCoder.class, ModelCoders.DOUBLE_CODER_URN)
+          .put(RowCoder.class, ModelCoders.ROW_CODER_URN)
           .build();
 
   public static final Set<String> WELL_KNOWN_CODER_URNS = BEAM_MODEL_CODER_URNS.values();
@@ -79,6 +81,7 @@ public class ModelCoderRegistrar implements CoderTranslatorRegistrar {
           .put(LengthPrefixCoder.class, CoderTranslators.lengthPrefix())
           .put(FullWindowedValueCoder.class, CoderTranslators.fullWindowedValue())
           .put(DoubleCoder.class, CoderTranslators.atomic(DoubleCoder.class))
+          .put(RowCoder.class, CoderTranslators.row())
           .build();
 
   static {
diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java
index 8d1265ccdc59..486e39c3a127 100644
--- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java
+++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java
@@ -54,6 +54,8 @@ private ModelCoders() {}
   public static final String WINDOWED_VALUE_CODER_URN =
       getUrn(StandardCoders.Enum.WINDOWED_VALUE);
 
+  public static final String ROW_CODER_URN = getUrn(StandardCoders.Enum.ROW);
+
   private static final Set<String> MODEL_CODER_URNS =
       ImmutableSet.of(
           BYTES_CODER_URN,
@@ -67,7 +69,8 @@ private ModelCoders() {}
           GLOBAL_WINDOW_CODER_URN,
           INTERVAL_WINDOW_CODER_URN,
           WINDOWED_VALUE_CODER_URN,
-          DOUBLE_CODER_URN);
+          DOUBLE_CODER_URN,
+          ROW_CODER_URN);
 
   public static Set<String> urns() {
     return MODEL_CODER_URNS;
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java
index dc28b7973c9a..a6368aa5cc50 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java
@@ -41,10 +41,14 @@
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
 import org.apache.beam.sdk.coders.LengthPrefixCoder;
+import org.apache.beam.sdk.coders.RowCoder;
 import org.apache.beam.sdk.coders.SerializableCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
-import org.apache.beam.sdk.coders.StructuredCoder;
 import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.schemas.LogicalTypes;
+import org.apache.beam.sdk.schemas.Schema;
+import org.apache.beam.sdk.schemas.Schema.Field;
+import org.apache.beam.sdk.schemas.Schema.FieldType;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.transforms.windowing.IntervalWindow.IntervalWindowCoder;
 import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
@@ -60,8 +64,8 @@
 
 /** Tests for {@link CoderTranslation}. */
 public class CoderTranslationTest {
-  private static final Set<StructuredCoder<?>> KNOWN_CODERS =
-      ImmutableSet.<StructuredCoder<?>>builder()
+  private static final Set<Coder<?>> KNOWN_CODERS =
+      ImmutableSet.<Coder<?>>builder()
           .add(ByteArrayCoder.of())
           .add(BooleanCoder.of())
           .add(KvCoder.of(VarLongCoder.of(), VarLongCoder.of()))
@@ -76,6 +80,13 @@ public class CoderTranslationTest {
               FullWindowedValueCoder.of(
                   IterableCoder.of(VarLongCoder.of()), IntervalWindowCoder.of()))
           .add(DoubleCoder.of())
+          .add(
+              RowCoder.of(
+                  Schema.of(
+                      Field.of("i16", FieldType.INT16),
+                      Field.of("array", FieldType.array(FieldType.STRING)),
+                      Field.of("map", FieldType.map(FieldType.STRING, FieldType.INT32)),
+                      Field.of("bar", FieldType.logicalType(LogicalTypes.FixedBytes.of(123))))))
           .build();
 
   /**
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java
index 1cedd5df9b42..52dddcc75f8b 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CommonCoderTest.java
@@ -20,6 +20,8 @@
 import static org.apache.beam.runners.core.construction.BeamUrns.getUrn;
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects.firstNonNull;
 import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkNotNull;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList.toImmutableList;
+import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap.toImmutableMap;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.hasItem;
 import static org.hamcrest.Matchers.instanceOf;
@@ -46,8 +48,8 @@
 import java.util.Map;
 import javax.annotation.Nullable;
 import org.apache.beam.model.pipeline.v1.RunnerApi.StandardCoders;
+import org.apache.beam.model.pipeline.v1.SchemaApi;
 import org.apache.beam.sdk.coders.BooleanCoder;
-import org.apache.beam.sdk.coders.ByteArrayCoder;
 import org.apache.beam.sdk.coders.ByteCoder;
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.coders.Coder.Context;
@@ -55,8 +57,10 @@
 import org.apache.beam.sdk.coders.DoubleCoder;
 import org.apache.beam.sdk.coders.IterableCoder;
 import org.apache.beam.sdk.coders.KvCoder;
+import org.apache.beam.sdk.coders.RowCoder;
 import org.apache.beam.sdk.coders.StringUtf8Coder;
 import org.apache.beam.sdk.coders.VarLongCoder;
+import org.apache.beam.sdk.schemas.Schema;
 import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
 import org.apache.beam.sdk.transforms.windowing.GlobalWindow;
 import org.apache.beam.sdk.transforms.windowing.IntervalWindow;
@@ -65,6 +69,8 @@
 import org.apache.beam.sdk.util.CoderUtils;
 import org.apache.beam.sdk.util.WindowedValue;
 import org.apache.beam.sdk.values.KV;
+import org.apache.beam.sdk.values.Row;
+import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.InvalidProtocolBufferException;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.MoreObjects;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Splitter;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
@@ -99,6 +105,7 @@ public class CommonCoderTest {
           .put(
               getUrn(StandardCoders.Enum.WINDOWED_VALUE),
               WindowedValue.FullWindowedValueCoder.class)
+          .put(getUrn(StandardCoders.Enum.ROW), RowCoder.class)
           .build();
 
   @AutoValue
@@ -107,16 +114,21 @@ abstract static class CommonCoder {
     abstract List<CommonCoder> getComponents();
 
+    @SuppressWarnings("mutable")
+    abstract byte[] getPayload();
+
     abstract Boolean getNonDeterministic();
 
     @JsonCreator
     static CommonCoder create(
         @JsonProperty("urn") String urn,
         @JsonProperty("components") @Nullable List<CommonCoder> components,
+        @JsonProperty("payload") @Nullable String payload,
         @JsonProperty("non_deterministic") @Nullable Boolean nonDeterministic) {
       return new AutoValue_CommonCoderTest_CommonCoder(
           checkNotNull(urn, "urn"),
           firstNonNull(components, Collections.emptyList()),
+          firstNonNull(payload, "").getBytes(StandardCharsets.ISO_8859_1),
          firstNonNull(nonDeterministic, Boolean.FALSE));
     }
   }
@@ -282,43 +294,90 @@ private static Object convertValue(Object value, CommonCoder coderSpec, Coder coder) {
       return WindowedValue.of(windowValue, timestamp, windows, paneInfo);
     } else if (s.equals(getUrn(StandardCoders.Enum.DOUBLE))) {
       return Double.parseDouble((String) value);
+    } else if (s.equals(getUrn(StandardCoders.Enum.ROW))) {
+      Schema schema;
+      try {
+        schema = SchemaTranslation.fromProto(SchemaApi.Schema.parseFrom(coderSpec.getPayload()));
+      } catch (InvalidProtocolBufferException e) {
+        throw new RuntimeException("Failed to parse schema payload for row coder", e);
+      }
+
+      return parseField(value, Schema.FieldType.row(schema));
     } else {
       throw new IllegalStateException("Unknown coder URN: " + coderSpec.getUrn());
     }
   }
 
+  private static Object parseField(Object value, Schema.FieldType fieldType) {
+    switch (fieldType.getTypeName()) {
+      case BYTE:
+        return ((Number) value).byteValue();
+      case INT16:
+        return ((Number) value).shortValue();
+      case INT32:
+        return ((Number) value).intValue();
+      case INT64:
+        return ((Number) value).longValue();
+      case FLOAT:
+        return Float.parseFloat((String) value);
+      case DOUBLE:
+        return Double.parseDouble((String) value);
+      case STRING:
+        return (String) value;
+      case BOOLEAN:
+        return (Boolean) value;
+      case BYTES:
+        // extract String as byte[]
+        return ((String) value).getBytes(StandardCharsets.ISO_8859_1);
+      case ARRAY:
+        return ((List<Object>) value)
+            .stream()
+            .map((element) -> parseField(element, fieldType.getCollectionElementType()))
+            .collect(toImmutableList());
+      case MAP:
+        Map<Object, Object> kvMap = (Map<Object, Object>) value;
+        return kvMap.entrySet().stream()
+            .collect(
+                toImmutableMap(
+                    (pair) -> parseField(pair.getKey(), fieldType.getMapKeyType()),
+                    (pair) -> parseField(pair.getValue(), fieldType.getMapValueType())));
+      case ROW:
+        Map<String, Object> rowMap = (Map<String, Object>) value;
+        Schema schema = fieldType.getRowSchema();
+        Row.Builder row = Row.withSchema(schema);
+        for (Schema.Field field : schema.getFields()) {
+          Object element = rowMap.remove(field.getName());
+          if (element != null) {
+            element = parseField(element, field.getType());
+          }
+          row.addValue(element);
+        }
+
+        if (!rowMap.isEmpty()) {
+          throw new IllegalArgumentException(
+              "Value contains keys that are not in the schema: " + rowMap.keySet());
+        }
+
+        return row.build();
+      default: // DECIMAL, DATETIME, LOGICAL_TYPE
+        throw new IllegalArgumentException("Unsupported type name: " + fieldType.getTypeName());
+    }
+  }
+
   private static Coder<?> instantiateCoder(CommonCoder coder) {
     List<Coder<?>> components = new ArrayList<>();
     for (CommonCoder innerCoder : coder.getComponents()) {
       components.add(instantiateCoder(innerCoder));
     }
-    String s = coder.getUrn();
-    if (s.equals(getUrn(StandardCoders.Enum.BYTES))) {
-      return ByteArrayCoder.of();
-    } else if (s.equals(getUrn(StandardCoders.Enum.BOOL))) {
-      return BooleanCoder.of();
-    } else if (s.equals(getUrn(StandardCoders.Enum.STRING_UTF8))) {
-      return StringUtf8Coder.of();
-    } else if (s.equals(getUrn(StandardCoders.Enum.KV))) {
-      return KvCoder.of(components.get(0), components.get(1));
-    } else if (s.equals(getUrn(StandardCoders.Enum.VARINT))) {
-      return VarLongCoder.of();
-    } else if (s.equals(getUrn(StandardCoders.Enum.INTERVAL_WINDOW))) {
-      return IntervalWindowCoder.of();
-    } else if (s.equals(getUrn(StandardCoders.Enum.ITERABLE))) {
-      return IterableCoder.of(components.get(0));
-    } else if (s.equals(getUrn(StandardCoders.Enum.TIMER))) {
-      return Timer.Coder.of(components.get(0));
-    } else if (s.equals(getUrn(StandardCoders.Enum.GLOBAL_WINDOW))) {
-      return GlobalWindow.Coder.INSTANCE;
-    } else if (s.equals(getUrn(StandardCoders.Enum.WINDOWED_VALUE))) {
-      return WindowedValue.FullWindowedValueCoder.of(
-          components.get(0), (Coder<BoundedWindow>) components.get(1));
-    } else if (s.equals(getUrn(StandardCoders.Enum.DOUBLE))) {
-      return DoubleCoder.of();
-    } else {
-      throw new IllegalStateException("Unknown coder URN: " + coder.getUrn());
-    }
+    Class<? extends Coder> coderType =
+        ModelCoderRegistrar.BEAM_MODEL_CODER_URNS.inverse().get(coder.getUrn());
+    checkNotNull(coderType, "Unknown coder URN: " + coder.getUrn());
+
+    CoderTranslator<?> translator = ModelCoderRegistrar.BEAM_MODEL_CODERS.get(coderType);
+    checkNotNull(
+        translator, "No translator found for common coder class: " + coderType.getSimpleName());
+
+    return translator.fromComponents(components, coder.getPayload());
   }
 
   @Test
@@ -380,6 +439,8 @@ private void verifyDecodedValue(CommonCoder coder, Object expectedValue, Object actualValue) {
     } else if (s.equals(getUrn(StandardCoders.Enum.DOUBLE))) {
+      assertEquals(expectedValue, actualValue);
+    } else if (s.equals(getUrn(StandardCoders.Enum.ROW))) {
       assertEquals(expectedValue, actualValue);
     } else {
       throw new IllegalStateException("Unknown coder URN: " + coder.getUrn());
diff --git a/sdks/python/apache_beam/coders/__init__.py b/sdks/python/apache_beam/coders/__init__.py
index 3192494ebbf1..680f1c725cbb
100644 --- a/sdks/python/apache_beam/coders/__init__.py +++ b/sdks/python/apache_beam/coders/__init__.py @@ -17,4 +17,5 @@ from __future__ import absolute_import from apache_beam.coders.coders import * +from apache_beam.coders.row_coder import * from apache_beam.coders.typecoders import registry diff --git a/sdks/python/apache_beam/coders/row_coder.py b/sdks/python/apache_beam/coders/row_coder.py new file mode 100644 index 000000000000..a259f362b2c1 --- /dev/null +++ b/sdks/python/apache_beam/coders/row_coder.py @@ -0,0 +1,174 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import absolute_import + +import itertools +from array import array + +from apache_beam.coders.coder_impl import StreamCoderImpl +from apache_beam.coders.coders import BytesCoder +from apache_beam.coders.coders import Coder +from apache_beam.coders.coders import FastCoder +from apache_beam.coders.coders import FloatCoder +from apache_beam.coders.coders import IterableCoder +from apache_beam.coders.coders import StrUtf8Coder +from apache_beam.coders.coders import TupleCoder +from apache_beam.coders.coders import VarIntCoder +from apache_beam.portability import common_urns +from apache_beam.portability.api import schema_pb2 +from apache_beam.typehints.schemas import named_tuple_from_schema +from apache_beam.typehints.schemas import named_tuple_to_schema + +__all__ = ["RowCoder"] + + +class RowCoder(FastCoder): + """ Coder for `typing.NamedTuple` instances. + + Implements the beam:coder:row:v1 standard coder spec. + """ + + def __init__(self, schema): + """Initializes a :class:`RowCoder`. + + Args: + schema (apache_beam.portability.api.schema_pb2.Schema): The protobuf + representation of the schema of the data that the RowCoder will be used + to encode/decode. 
+ """ + self.schema = schema + self.components = [ + RowCoder.coder_from_type(field.type) for field in self.schema.fields + ] + + def _create_impl(self): + return RowCoderImpl(self.schema, self.components) + + def is_deterministic(self): + return all(c.is_deterministic() for c in self.components) + + def to_type_hint(self): + return named_tuple_from_schema(self.schema) + + def as_cloud_object(self, coders_context=None): + raise NotImplementedError("as_cloud_object not supported for RowCoder") + + __hash__ = None + + def __eq__(self, other): + return type(self) == type(other) and self.schema == other.schema + + def to_runner_api_parameter(self, unused_context): + return (common_urns.coders.ROW.urn, self.schema, []) + + @Coder.register_urn(common_urns.coders.ROW.urn, schema_pb2.Schema) + def from_runner_api_parameter(payload, components, unused_context): + return RowCoder(payload) + + @staticmethod + def from_type_hint(named_tuple_type, registry): + return RowCoder(named_tuple_to_schema(named_tuple_type)) + + @staticmethod + def coder_from_type(field_type): + type_info = field_type.WhichOneof("type_info") + if type_info == "atomic_type": + if field_type.atomic_type in (schema_pb2.INT32, + schema_pb2.INT64): + return VarIntCoder() + elif field_type.atomic_type == schema_pb2.DOUBLE: + return FloatCoder() + elif field_type.atomic_type == schema_pb2.STRING: + return StrUtf8Coder() + elif type_info == "array_type": + return IterableCoder( + RowCoder.coder_from_type(field_type.array_type.element_type)) + + # The Java SDK supports several more types, but the coders are not yet + # standard, and are not implemented in Python. + raise ValueError( + "Encountered a type that is not currently supported by RowCoder: %s" % + field_type) + + +class RowCoderImpl(StreamCoderImpl): + """For internal use only; no backwards-compatibility guarantees.""" + SIZE_CODER = VarIntCoder().get_impl() + NULL_MARKER_CODER = BytesCoder().get_impl() + + def __init__(self, schema, components): + self.schema = schema + self.constructor = named_tuple_from_schema(schema) + self.components = list(c.get_impl() for c in components) + self.has_nullable_fields = any( + field.type.nullable for field in self.schema.fields) + + def encode_to_stream(self, value, out, nested): + nvals = len(self.schema.fields) + self.SIZE_CODER.encode_to_stream(nvals, out, True) + attrs = [getattr(value, f.name) for f in self.schema.fields] + + words = array('B') + if self.has_nullable_fields: + nulls = list(attr is None for attr in attrs) + if any(nulls): + words = array('B', itertools.repeat(0, (nvals+7)//8)) + for i, is_null in enumerate(nulls): + words[i//8] |= is_null << (i % 8) + + self.NULL_MARKER_CODER.encode_to_stream(words.tostring(), out, True) + + for c, field, attr in zip(self.components, self.schema.fields, attrs): + if attr is None: + if not field.type.nullable: + raise ValueError( + "Attempted to encode null for non-nullable field \"{}\".".format( + field.name)) + continue + c.encode_to_stream(attr, out, True) + + def decode_from_stream(self, in_stream, nested): + nvals = self.SIZE_CODER.decode_from_stream(in_stream, True) + words = array('B') + words.fromstring(self.NULL_MARKER_CODER.decode_from_stream(in_stream, True)) + + if words: + nulls = ((words[i // 8] >> (i % 8)) & 0x01 for i in range(nvals)) + else: + nulls = itertools.repeat(False, nvals) + + # If this coder's schema has more attributes than the encoded value, then + # the schema must have changed. Populate the unencoded fields with nulls. 
+ if len(self.components) > nvals: + nulls = itertools.chain( + nulls, + itertools.repeat(True, len(self.components) - nvals)) + + # Note that if this coder's schema has *fewer* attributes than the encoded + # value, we just need to ignore the additional values, which will occur + # here because we only decode as many values as we have coders for. + return self.constructor(*( + None if is_null else c.decode_from_stream(in_stream, True) + for c, is_null in zip(self.components, nulls))) + + def _make_value_coder(self, nulls=itertools.repeat(False)): + components = [ + component for component, is_null in zip(self.components, nulls) + if not is_null + ] if self.has_nullable_fields else self.components + return TupleCoder(components).get_impl() diff --git a/sdks/python/apache_beam/coders/row_coder_test.py b/sdks/python/apache_beam/coders/row_coder_test.py new file mode 100644 index 000000000000..dbdc5fc99a89 --- /dev/null +++ b/sdks/python/apache_beam/coders/row_coder_test.py @@ -0,0 +1,168 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +from __future__ import absolute_import + +import logging +import typing +import unittest +from itertools import chain + +import numpy as np +from past.builtins import unicode + +from apache_beam.coders import RowCoder +from apache_beam.coders.typecoders import registry as coders_registry +from apache_beam.portability.api import schema_pb2 +from apache_beam.typehints.schemas import typing_to_runner_api + +Person = typing.NamedTuple("Person", [ + ("name", unicode), + ("age", np.int32), + ("address", typing.Optional[unicode]), + ("aliases", typing.List[unicode]), +]) + +coders_registry.register_coder(Person, RowCoder) + + +class RowCoderTest(unittest.TestCase): + TEST_CASES = [ + Person("Jon Snow", 23, None, ["crow", "wildling"]), + Person("Daenerys Targaryen", 25, "Westeros", ["Mother of Dragons"]), + Person("Michael Bluth", 30, None, []) + ] + + def test_create_row_coder_from_named_tuple(self): + expected_coder = RowCoder(typing_to_runner_api(Person).row_type.schema) + real_coder = coders_registry.get_coder(Person) + + for test_case in self.TEST_CASES: + self.assertEqual( + expected_coder.encode(test_case), real_coder.encode(test_case)) + + self.assertEqual(test_case, + real_coder.decode(real_coder.encode(test_case))) + + def test_create_row_coder_from_schema(self): + schema = schema_pb2.Schema( + id="person", + fields=[ + schema_pb2.Field( + name="name", + type=schema_pb2.FieldType( + atomic_type=schema_pb2.STRING)), + schema_pb2.Field( + name="age", + type=schema_pb2.FieldType( + atomic_type=schema_pb2.INT32)), + schema_pb2.Field( + name="address", + type=schema_pb2.FieldType( + atomic_type=schema_pb2.STRING, nullable=True)), + schema_pb2.Field( + name="aliases", + type=schema_pb2.FieldType( + array_type=schema_pb2.ArrayType( + element_type=schema_pb2.FieldType( + atomic_type=schema_pb2.STRING)))), + ]) + coder = RowCoder(schema) + + for test_case in self.TEST_CASES: + self.assertEqual(test_case, coder.decode(coder.encode(test_case))) + + @unittest.skip( + "BEAM-8030 - Overflow behavior in VarIntCoder is currently inconsistent" + ) + def test_overflows(self): + IntTester = typing.NamedTuple('IntTester', [ + # TODO(BEAM-7996): Test int8 and int16 here as well when those types are + # supported + # ('i8', typing.Optional[np.int8]), + # ('i16', typing.Optional[np.int16]), + ('i32', typing.Optional[np.int32]), + ('i64', typing.Optional[np.int64]), + ]) + + c = RowCoder.from_type_hint(IntTester, None) + + no_overflow = chain( + (IntTester(i32=i, i64=None) for i in (-2**31, 2**31-1)), + (IntTester(i32=None, i64=i) for i in (-2**63, 2**63-1)), + ) + + # Encode max/min ints to make sure they don't throw any error + for case in no_overflow: + c.encode(case) + + overflow = chain( + (IntTester(i32=i, i64=None) for i in (-2**31-1, 2**31)), + (IntTester(i32=None, i64=i) for i in (-2**63-1, 2**63)), + ) + + # Encode max+1/min-1 ints to make sure they DO throw an error + for case in overflow: + self.assertRaises(OverflowError, lambda: c.encode(case)) + + def test_none_in_non_nullable_field_throws(self): + Test = typing.NamedTuple('Test', [('foo', unicode)]) + + c = RowCoder.from_type_hint(Test, None) + self.assertRaises(ValueError, lambda: c.encode(Test(foo=None))) + + def test_schema_remove_column(self): + fields = [("field1", unicode), ("field2", unicode)] + # new schema is missing one field that was in the old schema + Old = typing.NamedTuple('Old', fields) + New = typing.NamedTuple('New', fields[:-1]) + + old_coder = RowCoder.from_type_hint(Old, None) + new_coder = RowCoder.from_type_hint(New, 
None)
+
+    self.assertEqual(
+        New("foo"), new_coder.decode(old_coder.encode(Old("foo", "bar"))))
+
+  def test_schema_add_column(self):
+    fields = [("field1", unicode), ("field2", typing.Optional[unicode])]
+    # new schema has one (optional) field that didn't exist in the old schema
+    Old = typing.NamedTuple('Old', fields[:-1])
+    New = typing.NamedTuple('New', fields)
+
+    old_coder = RowCoder.from_type_hint(Old, None)
+    new_coder = RowCoder.from_type_hint(New, None)
+
+    self.assertEqual(
+        New("bar", None), new_coder.decode(old_coder.encode(Old("bar"))))
+
+  def test_schema_add_column_with_null_value(self):
+    fields = [("field1", typing.Optional[unicode]), ("field2", unicode),
+              ("field3", typing.Optional[unicode])]
+    # new schema has one (optional) field that didn't exist in the old schema
+    Old = typing.NamedTuple('Old', fields[:-1])
+    New = typing.NamedTuple('New', fields)
+
+    old_coder = RowCoder.from_type_hint(Old, None)
+    new_coder = RowCoder.from_type_hint(New, None)
+
+    self.assertEqual(
+        New(None, "baz", None),
+        new_coder.decode(old_coder.encode(Old(None, "baz"))))
+
+
+if __name__ == "__main__":
+  logging.getLogger().setLevel(logging.INFO)
+  unittest.main()
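The registration pattern these tests exercise amounts to just a few lines in user code. A minimal usage sketch follows (the Movie type is hypothetical, introduced only for illustration):

    import typing

    import numpy as np
    from past.builtins import unicode

    from apache_beam.coders import RowCoder
    from apache_beam.coders.typecoders import registry as coders_registry

    # Optional fields become nullable fields in the generated schema.
    Movie = typing.NamedTuple("Movie", [
        ("title", unicode),
        ("year", np.int32),
        ("rating", typing.Optional[np.float64]),
    ])
    coders_registry.register_coder(Movie, RowCoder)

    coder = coders_registry.get_coder(Movie)  # a RowCoder derived from the NamedTuple
    value = Movie(u"Up", 2009, None)
    assert coder.decode(coder.encode(value)) == value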
diff --git a/sdks/python/apache_beam/coders/standard_coders_test.py b/sdks/python/apache_beam/coders/standard_coders_test.py
index 606ca811ed87..5ffbeea0517c 100644
--- a/sdks/python/apache_beam/coders/standard_coders_test.py
+++ b/sdks/python/apache_beam/coders/standard_coders_test.py
@@ -32,9 +32,11 @@
 from apache_beam.coders import coder_impl
 from apache_beam.portability.api import beam_runner_api_pb2
+from apache_beam.portability.api import schema_pb2
 from apache_beam.runners import pipeline_context
 from apache_beam.transforms import window
 from apache_beam.transforms.window import IntervalWindow
+from apache_beam.typehints import schemas
 from apache_beam.utils import windowed_value
 from apache_beam.utils.timestamp import Timestamp
@@ -65,6 +67,42 @@ def parse_float(s):
     return x
 
 
+def value_parser_from_schema(schema):
+  def attribute_parser_from_type(type_):
+    # TODO: This should be exhaustive
+    type_info = type_.WhichOneof("type_info")
+    if type_info == "atomic_type":
+      return schemas.ATOMIC_TYPE_TO_PRIMITIVE[type_.atomic_type]
+    elif type_info == "array_type":
+      element_parser = attribute_parser_from_type(type_.array_type.element_type)
+      return lambda x: list(map(element_parser, x))
+    elif type_info == "map_type":
+      key_parser = attribute_parser_from_type(type_.map_type.key_type)
+      value_parser = attribute_parser_from_type(type_.map_type.value_type)
+      return lambda x: dict((key_parser(k), value_parser(v))
+                            for k, v in x.items())
+
+  parsers = [(field.name, attribute_parser_from_type(field.type))
+             for field in schema.fields]
+
+  constructor = schemas.named_tuple_from_schema(schema)
+
+  def value_parser(x):
+    result = []
+    for name, parser in parsers:
+      value = x.pop(name)
+      result.append(None if value is None else parser(value))
+
+    if len(x):
+      raise ValueError(
+          "Test data contains attributes that don't exist in the schema: {}"
+          .format(', '.join(x.keys())))
+
+    return constructor(*result)
+
+  return value_parser
+
+
 class StandardCodersTest(unittest.TestCase):
 
   _urn_to_json_value_parser = {
@@ -134,11 +172,17 @@ def parse_coder(self, spec):
               for c in spec.get('components', ())]
     context.coders.put_proto(coder_id, beam_runner_api_pb2.Coder(
         spec=beam_runner_api_pb2.FunctionSpec(
-            urn=spec['urn'], payload=spec.get('payload')),
+            urn=spec['urn'], payload=spec.get('payload', '').encode('latin1')),
         component_coder_ids=component_ids))
     return context.coders.get_by_id(coder_id)
 
   def json_value_parser(self, coder_spec):
+    # TODO: integrate this with the logic for the other parsers
+    if coder_spec['urn'] == 'beam:coder:row:v1':
+      schema = schema_pb2.Schema.FromString(
+          coder_spec['payload'].encode('latin1'))
+      return value_parser_from_schema(schema)
+
     component_parsers = [
         self.json_value_parser(c) for c in coder_spec.get('components', ())]
     return lambda x: self._urn_to_json_value_parser[coder_spec['urn']](
diff --git a/sdks/python/apache_beam/typehints/native_type_compatibility.py b/sdks/python/apache_beam/typehints/native_type_compatibility.py
index 43cdedc60163..d73a1cf96e73 100644
--- a/sdks/python/apache_beam/typehints/native_type_compatibility.py
+++ b/sdks/python/apache_beam/typehints/native_type_compatibility.py
@@ -50,6 +50,17 @@ def _get_compatible_args(typ):
   return None
 
 
+def _get_args(typ):
+  """Returns all of the type arguments to the given type."""
+  try:
+    return typ.__args__
+  except AttributeError:
+    compatible_args = _get_compatible_args(typ)
+    if compatible_args is None:
+      raise
+    return compatible_args
+
+
 def _get_arg(typ, index):
   """Returns the index-th argument to the given type."""
   try:
@@ -105,6 +116,15 @@ def _match_same_type(match_against):
   return lambda user_type: type(user_type) == type(match_against)
 
 
+def _match_is_exactly_mapping(user_type):
+  # Avoid unintentionally catching all subtypes (e.g. strings and mappings).
+  if sys.version_info < (3, 7):
+    expected_origin = typing.Mapping
+  else:
+    expected_origin = collections.abc.Mapping
+  return getattr(user_type, '__origin__', None) is expected_origin
+
+
 def _match_is_exactly_iterable(user_type):
   # Avoid unintentionally catching all subtypes (e.g. strings and mappings).
   if sys.version_info < (3, 7):
@@ -119,6 +139,22 @@ def _match_is_named_tuple(user_type):
           hasattr(user_type, '_field_types'))
 
 
+def _match_is_optional(user_type):
+  return _match_is_union(user_type) and sum(
+      tp is type(None) for tp in _get_args(user_type)) == 1
+
+
+def extract_optional_type(user_type):
+  """Extracts the non-None type from Optional type user_type.
+
+  If user_type is not Optional, returns None.
+  """
+  if not _match_is_optional(user_type):
+    return None
+  else:
+    return next(tp for tp in _get_args(user_type) if tp is not type(None))
+
+
 def _match_is_union(user_type):
   # For non-subscripted unions (Python 2.7.14+ with typing 3.64)
   if user_type is typing.Union:
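The behavior of the new Optional helpers above can be summarized with two quick checks (a sketch; note that extract_optional_type signals "not Optional" by returning None rather than raising):

    import typing

    from apache_beam.typehints.native_type_compatibility import extract_optional_type

    # Optional[T] is typing.Union[T, None]; the helper peels off the None member.
    assert extract_optional_type(typing.Optional[int]) is int
    # A non-Optional type yields None.
    assert extract_optional_type(int) is None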
diff --git a/sdks/python/apache_beam/typehints/schemas.py b/sdks/python/apache_beam/typehints/schemas.py
new file mode 100644
index 000000000000..812cbe1fc32d
--- /dev/null
+++ b/sdks/python/apache_beam/typehints/schemas.py
@@ -0,0 +1,218 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+""" Support for mapping python types to proto Schemas and back again.
+
+Python      Schema
+np.int8     <-----> BYTE
+np.int16    <-----> INT16
+np.int32    <-----> INT32
+np.int64    <-----> INT64
+int         ---/
+np.float32  <-----> FLOAT
+np.float64  <-----> DOUBLE
+float       ---/
+bool        <-----> BOOLEAN
+
+The mappings for STRING and BYTES are different between python 2 and python 3,
+because of the changes to str:
+py3:
+str/unicode <-----> STRING
+bytes       <-----> BYTES
+ByteString  ---/
+
+py2:
+str will be rejected since it is ambiguous.
+unicode     <-----> STRING
+ByteString  <-----> BYTES
+"""
+
+from __future__ import absolute_import
+
+import sys
+from typing import ByteString
+from typing import Mapping
+from typing import NamedTuple
+from typing import Optional
+from typing import Sequence
+from uuid import uuid4
+
+import numpy as np
+from past.builtins import unicode
+
+from apache_beam.portability.api import schema_pb2
+from apache_beam.typehints.native_type_compatibility import _get_args
+from apache_beam.typehints.native_type_compatibility import _match_is_exactly_mapping
+from apache_beam.typehints.native_type_compatibility import _match_is_named_tuple
+from apache_beam.typehints.native_type_compatibility import _match_is_optional
+from apache_beam.typehints.native_type_compatibility import _safe_issubclass
+from apache_beam.typehints.native_type_compatibility import extract_optional_type
+
+
+# Registry of typings for a schema by UUID
+class SchemaTypeRegistry(object):
+  def __init__(self):
+    self.by_id = {}
+    self.by_typing = {}
+
+  def add(self, typing, schema):
+    self.by_id[schema.id] = (typing, schema)
+
+  def get_typing_by_id(self, unique_id):
+    result = self.by_id.get(unique_id, None)
+    return result[0] if result is not None else None
+
+  def get_schema_by_id(self, unique_id):
+    result = self.by_id.get(unique_id, None)
+    return result[1] if result is not None else None
+
+
+SCHEMA_REGISTRY = SchemaTypeRegistry()
+
+
+# Bi-directional mappings
+_PRIMITIVES = (
+    (np.int8, schema_pb2.BYTE),
+    (np.int16, schema_pb2.INT16),
+    (np.int32, schema_pb2.INT32),
+    (np.int64, schema_pb2.INT64),
+    (np.float32, schema_pb2.FLOAT),
+    (np.float64, schema_pb2.DOUBLE),
+    (unicode, schema_pb2.STRING),
+    (bool, schema_pb2.BOOLEAN),
+    (bytes if sys.version_info.major >= 3 else ByteString,
+     schema_pb2.BYTES),
+)
+
+PRIMITIVE_TO_ATOMIC_TYPE = dict((typ, atomic) for typ, atomic in _PRIMITIVES)
+ATOMIC_TYPE_TO_PRIMITIVE = dict((atomic, typ) for typ, atomic in _PRIMITIVES)
+
+# One-way mappings
+PRIMITIVE_TO_ATOMIC_TYPE.update({
+    # In python 2, this is a no-op because we define it as the bi-directional
+    # mapping above. This just ensures the one-way mapping is defined in python
+    # 3.
+    ByteString: schema_pb2.BYTES,
+    # Allow users to specify a native int, and use INT64 as the cross-language
+    # representation. Technically ints have unlimited precision, but RowCoder
+    # should throw an error if it sees one with a bit width > 64 when encoding.
+ int: schema_pb2.INT64, + float: schema_pb2.DOUBLE, +}) + + +def typing_to_runner_api(type_): + if _match_is_named_tuple(type_): + schema = None + if hasattr(type_, 'id'): + schema = SCHEMA_REGISTRY.get_schema_by_id(type_.id) + if schema is None: + fields = [ + schema_pb2.Field( + name=name, type=typing_to_runner_api(type_._field_types[name])) + for name in type_._fields + ] + type_id = str(uuid4()) + schema = schema_pb2.Schema(fields=fields, id=type_id) + SCHEMA_REGISTRY.add(type_, schema) + + return schema_pb2.FieldType( + row_type=schema_pb2.RowType( + schema=schema)) + + # All concrete types (other than NamedTuple sub-classes) should map to + # a supported primitive type. + elif type_ in PRIMITIVE_TO_ATOMIC_TYPE: + return schema_pb2.FieldType(atomic_type=PRIMITIVE_TO_ATOMIC_TYPE[type_]) + + elif sys.version_info.major == 2 and type_ == str: + raise ValueError( + "type 'str' is not supported in python 2. Please use 'unicode' or " + "'typing.ByteString' instead to unambiguously indicate if this is a " + "UTF-8 string or a byte array." + ) + + elif _match_is_exactly_mapping(type_): + key_type, value_type = map(typing_to_runner_api, _get_args(type_)) + return schema_pb2.FieldType( + map_type=schema_pb2.MapType(key_type=key_type, value_type=value_type)) + + elif _match_is_optional(type_): + # It's possible that a user passes us Optional[Optional[T]], but in python + # typing this is indistinguishable from Optional[T] - both resolve to + # Union[T, None] - so there's no need to check for that case here. + result = typing_to_runner_api(extract_optional_type(type_)) + result.nullable = True + return result + + elif _safe_issubclass(type_, Sequence): + element_type = typing_to_runner_api(_get_args(type_)[0]) + return schema_pb2.FieldType( + array_type=schema_pb2.ArrayType(element_type=element_type)) + + raise ValueError("Unsupported type: %s" % type_) + + +def typing_from_runner_api(fieldtype_proto): + if fieldtype_proto.nullable: + # In order to determine the inner type, create a copy of fieldtype_proto + # with nullable=False and pass back to typing_from_runner_api + base_type = schema_pb2.FieldType() + base_type.CopyFrom(fieldtype_proto) + base_type.nullable = False + return Optional[typing_from_runner_api(base_type)] + + type_info = fieldtype_proto.WhichOneof("type_info") + if type_info == "atomic_type": + try: + return ATOMIC_TYPE_TO_PRIMITIVE[fieldtype_proto.atomic_type] + except KeyError: + raise ValueError("Unsupported atomic type: {0}".format( + fieldtype_proto.atomic_type)) + elif type_info == "array_type": + return Sequence[typing_from_runner_api( + fieldtype_proto.array_type.element_type)] + elif type_info == "map_type": + return Mapping[ + typing_from_runner_api(fieldtype_proto.map_type.key_type), + typing_from_runner_api(fieldtype_proto.map_type.value_type) + ] + elif type_info == "row_type": + schema = fieldtype_proto.row_type.schema + user_type = SCHEMA_REGISTRY.get_typing_by_id(schema.id) + if user_type is None: + from apache_beam import coders + type_name = 'BeamSchema_{}'.format(schema.id.replace('-', '_')) + user_type = NamedTuple(type_name, + [(field.name, typing_from_runner_api(field.type)) + for field in schema.fields]) + user_type.id = schema.id + SCHEMA_REGISTRY.add(user_type, schema) + coders.registry.register_coder(user_type, coders.RowCoder) + return user_type + + elif type_info == "logical_type": + pass # TODO + + +def named_tuple_from_schema(schema): + return typing_from_runner_api( + schema_pb2.FieldType(row_type=schema_pb2.RowType(schema=schema))) + + +def 
named_tuple_to_schema(named_tuple): + return typing_to_runner_api(named_tuple).row_type.schema diff --git a/sdks/python/apache_beam/typehints/schemas_test.py b/sdks/python/apache_beam/typehints/schemas_test.py new file mode 100644 index 000000000000..9dd1bc22ad1c --- /dev/null +++ b/sdks/python/apache_beam/typehints/schemas_test.py @@ -0,0 +1,270 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for schemas.""" + +from __future__ import absolute_import + +import itertools +import sys +import unittest +from typing import ByteString +from typing import List +from typing import Mapping +from typing import NamedTuple +from typing import Optional +from typing import Sequence + +import numpy as np +from past.builtins import unicode + +from apache_beam.portability.api import schema_pb2 +from apache_beam.typehints.schemas import typing_from_runner_api +from apache_beam.typehints.schemas import typing_to_runner_api + +IS_PYTHON_3 = sys.version_info.major > 2 + + +class SchemaTest(unittest.TestCase): + """ Tests for Runner API Schema proto to/from typing conversions + + There are two main tests: test_typing_survives_proto_roundtrip, and + test_proto_survives_typing_roundtrip. These are both necessary because Schemas + are cached by ID, so performing just one of them wouldn't necessarily exercise + all code paths. + """ + + def test_typing_survives_proto_roundtrip(self): + all_nonoptional_primitives = [ + np.int8, + np.int16, + np.int32, + np.int64, + np.float32, + np.float64, + unicode, + bool, + ] + + # The bytes type cannot survive a roundtrip to/from proto in Python 2. + # In order to use BYTES a user type has to use typing.ByteString (because + # bytes == str, and we map str to STRING). 
+ if IS_PYTHON_3: + all_nonoptional_primitives.extend([bytes]) + + all_optional_primitives = [ + Optional[typ] for typ in all_nonoptional_primitives + ] + + all_primitives = all_nonoptional_primitives + all_optional_primitives + + basic_array_types = [Sequence[typ] for typ in all_primitives] + + basic_map_types = [ + Mapping[key_type, + value_type] for key_type, value_type in itertools.product( + all_primitives, all_primitives) + ] + + selected_schemas = [ + NamedTuple( + 'AllPrimitives', + [('field%d' % i, typ) for i, typ in enumerate(all_primitives)]), + NamedTuple('ComplexSchema', [ + ('id', np.int64), + ('name', unicode), + ('optional_map', Optional[Mapping[unicode, + Optional[np.float64]]]), + ('optional_array', Optional[Sequence[np.float32]]), + ('array_optional', Sequence[Optional[bool]]), + ]) + ] + + test_cases = all_primitives + \ + basic_array_types + \ + basic_map_types + \ + selected_schemas + + for test_case in test_cases: + self.assertEqual(test_case, + typing_from_runner_api(typing_to_runner_api(test_case))) + + def test_proto_survives_typing_roundtrip(self): + all_nonoptional_primitives = [ + schema_pb2.FieldType(atomic_type=typ) + for typ in schema_pb2.AtomicType.values() + if typ is not schema_pb2.UNSPECIFIED + ] + + # The bytes type cannot survive a roundtrip to/from proto in Python 2. + # In order to use BYTES a user type has to use typing.ByteString (because + # bytes == str, and we map str to STRING). + if not IS_PYTHON_3: + all_nonoptional_primitives.remove( + schema_pb2.FieldType(atomic_type=schema_pb2.BYTES)) + + all_optional_primitives = [ + schema_pb2.FieldType(nullable=True, atomic_type=typ) + for typ in schema_pb2.AtomicType.values() + if typ is not schema_pb2.UNSPECIFIED + ] + + all_primitives = all_nonoptional_primitives + all_optional_primitives + + basic_array_types = [ + schema_pb2.FieldType(array_type=schema_pb2.ArrayType(element_type=typ)) + for typ in all_primitives + ] + + basic_map_types = [ + schema_pb2.FieldType( + map_type=schema_pb2.MapType( + key_type=key_type, value_type=value_type)) for key_type, + value_type in itertools.product(all_primitives, all_primitives) + ] + + selected_schemas = [ + schema_pb2.FieldType( + row_type=schema_pb2.RowType( + schema=schema_pb2.Schema( + id='32497414-85e8-46b7-9c90-9a9cc62fe390', + fields=[ + schema_pb2.Field(name='field%d' % i, type=typ) + for i, typ in enumerate(all_primitives) + ]))), + schema_pb2.FieldType( + row_type=schema_pb2.RowType( + schema=schema_pb2.Schema( + id='dead1637-3204-4bcb-acf8-99675f338600', + fields=[ + schema_pb2.Field( + name='id', + type=schema_pb2.FieldType( + atomic_type=schema_pb2.INT64)), + schema_pb2.Field( + name='name', + type=schema_pb2.FieldType( + atomic_type=schema_pb2.STRING)), + schema_pb2.Field( + name='optional_map', + type=schema_pb2.FieldType( + nullable=True, + map_type=schema_pb2.MapType( + key_type=schema_pb2.FieldType( + atomic_type=schema_pb2.STRING + ), + value_type=schema_pb2.FieldType( + atomic_type=schema_pb2.DOUBLE + )))), + schema_pb2.Field( + name='optional_array', + type=schema_pb2.FieldType( + nullable=True, + array_type=schema_pb2.ArrayType( + element_type=schema_pb2.FieldType( + atomic_type=schema_pb2.FLOAT) + ))), + schema_pb2.Field( + name='array_optional', + type=schema_pb2.FieldType( + array_type=schema_pb2.ArrayType( + element_type=schema_pb2.FieldType( + nullable=True, + atomic_type=schema_pb2.BYTES) + ))), + ]))), + ] + + test_cases = all_primitives + \ + basic_array_types + \ + basic_map_types + \ + selected_schemas + + for test_case in 
test_cases:
+      self.assertEqual(test_case,
+                       typing_to_runner_api(typing_from_runner_api(test_case)))
+
+  def test_unknown_primitive_raise_valueerror(self):
+    self.assertRaises(ValueError, lambda: typing_to_runner_api(np.uint32))
+
+  def test_unknown_atomic_raise_valueerror(self):
+    self.assertRaises(
+        ValueError, lambda: typing_from_runner_api(
+            schema_pb2.FieldType(atomic_type=schema_pb2.UNSPECIFIED))
+    )
+
+  @unittest.skipIf(IS_PYTHON_3, 'str is acceptable in python 3')
+  def test_str_raises_error_py2(self):
+    self.assertRaises(ValueError, lambda: typing_to_runner_api(str))
+    self.assertRaises(ValueError, lambda: typing_to_runner_api(
+        NamedTuple('Test', [('int', int), ('str', str)])))
+
+  def test_int_maps_to_int64(self):
+    self.assertEqual(
+        schema_pb2.FieldType(atomic_type=schema_pb2.INT64),
+        typing_to_runner_api(int))
+
+  def test_float_maps_to_float64(self):
+    self.assertEqual(
+        schema_pb2.FieldType(atomic_type=schema_pb2.DOUBLE),
+        typing_to_runner_api(float))
+
+  def test_trivial_example(self):
+    MyCuteClass = NamedTuple('MyCuteClass', [
+        ('name', unicode),
+        ('age', Optional[int]),
+        ('interests', List[unicode]),
+        ('height', float),
+        ('blob', ByteString),
+    ])
+
+    expected = schema_pb2.FieldType(
+        row_type=schema_pb2.RowType(
+            schema=schema_pb2.Schema(fields=[
+                schema_pb2.Field(
+                    name='name',
+                    type=schema_pb2.FieldType(
+                        atomic_type=schema_pb2.STRING),
+                ),
+                schema_pb2.Field(
+                    name='age',
+                    type=schema_pb2.FieldType(
+                        nullable=True,
+                        atomic_type=schema_pb2.INT64)),
+                schema_pb2.Field(
+                    name='interests',
+                    type=schema_pb2.FieldType(
+                        array_type=schema_pb2.ArrayType(
+                            element_type=schema_pb2.FieldType(
+                                atomic_type=schema_pb2.STRING)))),
+                schema_pb2.Field(
+                    name='height',
+                    type=schema_pb2.FieldType(
+                        atomic_type=schema_pb2.DOUBLE)),
+                schema_pb2.Field(
+                    name='blob',
+                    type=schema_pb2.FieldType(
+                        atomic_type=schema_pb2.BYTES)),
+            ])))
+
+    # Only test that the fields are equal. If we attempt to test the entire type
+    # or the entire schema, the generated id will break equality.
+    self.assertEqual(expected.row_type.schema.fields,
+                     typing_to_runner_api(MyCuteClass).row_type.schema.fields)
+
+
+if __name__ == '__main__':
+  unittest.main()
diff --git a/sdks/python/scripts/generate_pydoc.sh b/sdks/python/scripts/generate_pydoc.sh
index eab8aadaf895..e3794bae5c22 100755
--- a/sdks/python/scripts/generate_pydoc.sh
+++ b/sdks/python/scripts/generate_pydoc.sh
@@ -157,6 +157,7 @@ ignore_identifiers = [
   'apache_beam.metrics.metric.MetricResults',
   'apache_beam.pipeline.PipelineVisitor',
   'apache_beam.pipeline.PTransformOverride',
+  'apache_beam.portability.api.schema_pb2.Schema',
   'apache_beam.pvalue.AsSideInput',
   'apache_beam.pvalue.DoOutputsTuple',
   'apache_beam.pvalue.PValue',
diff --git a/sdks/python/setup.py b/sdks/python/setup.py
index 1cbc27f034ad..d8d5c7ddcaff 100644
--- a/sdks/python/setup.py
+++ b/sdks/python/setup.py
@@ -116,6 +116,7 @@ def get_version():
     'hdfs>=2.1.0,<3.0.0',
     'httplib2>=0.8,<=0.12.0',
     'mock>=1.0.1,<3.0.0',
+    'numpy>=1.14.3,<2',
     'pymongo>=3.8.0,<4.0.0',
     'oauth2client>=2.0.1,<4',
     'protobuf>=3.5.0.post1,<4',
@@ -139,7 +140,6 @@ def get_version():
 REQUIRED_TEST_PACKAGES = [
     'nose>=1.3.7',
     'nose_xunitmp>=0.4.1',
-    'numpy>=1.14.3,<2',
     'pandas>=0.23.4,<0.25',
     'parameterized>=0.6.0,<0.7.0',
     'pyhamcrest>=1.9,<2.0',
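Tying the typing layer together, here is a short sketch of the round trip that schemas_test.py exercises (the Reading type is hypothetical; the field comments show the expected schema mapping):

    import typing

    import numpy as np
    from past.builtins import unicode

    from apache_beam.typehints.schemas import named_tuple_from_schema
    from apache_beam.typehints.schemas import named_tuple_to_schema

    Reading = typing.NamedTuple("Reading", [
        ("sensor", unicode),                      # STRING
        ("values", typing.Sequence[np.float64]),  # ArrayType of DOUBLE
        ("note", typing.Optional[unicode]),       # STRING with nullable=True
    ])

    schema = named_tuple_to_schema(Reading)  # a schema_pb2.Schema with a generated id
    assert [f.name for f in schema.fields] == ["sensor", "values", "note"]
    # The generated id is cached in SCHEMA_REGISTRY, so the original type is returned.
    assert named_tuple_from_schema(schema) is Reading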
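Finally, a standalone Python 3 illustration of the null-bitmap packing that RowCoderImpl.encode_to_stream performs, reproducing the worked example from the beam:coder:row:v1 comment in beam_runner_api.proto (the attrs list is hypothetical):

    import itertools
    from array import array

    # A ten-field row; None marks the null fields (indices 0, 4, 7, and 9).
    attrs = [None, 0, 0, 0, None, 0, 0, None, 0, None]
    nvals = len(attrs)
    nulls = [attr is None for attr in attrs]

    # One bit per field, least-significant bit first within each byte; unused
    # bits in the last byte stay 0.
    words = array('B', itertools.repeat(0, (nvals + 7) // 8))
    for i, is_null in enumerate(nulls):
        words[i // 8] |= is_null << (i % 8)

    assert list(words) == [0b10010001, 0b00000010]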