diff --git a/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml b/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml index 10fcaa5a6ed7..7473f30ae6a2 100644 --- a/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml +++ b/model/fn-execution/src/main/resources/org/apache/beam/model/fnexecution/v1/standard_coders.yaml @@ -569,3 +569,15 @@ coder: examples: "\u0080\u0000\u0001\u0052\u009a\u00a4\u009b\u0067\u0080\u0000\u0001\u0052\u009a\u00a4\u009b\u0068\u0080\u00dd\u00db\u0001" : {window: {end: 1454293425000, span: 3600000}} "\u007f\u00df\u003b\u0064\u005a\u001c\u00ad\u0075\u007f\u00df\u003b\u0064\u005a\u001c\u00ad\u0076\u00ed\u0002" : {window: {end: -9223372036854410, span: 365}} + + +--- +coder: + urn: "beam:coder:nullable:v1" + components: [{urn: "beam:coder:bytes:v1"}] +nested: true + +examples: + "\u0001\u0003\u0061\u0062\u0063" : "abc" + "\u0001\u000a\u006d\u006f\u0072\u0065\u0020\u0062\u0079\u0074\u0065\u0073" : "more bytes" + "\u0000" : null diff --git a/model/pipeline/src/main/proto/beam_runner_api.proto b/model/pipeline/src/main/proto/beam_runner_api.proto index 7fdb5aaf5e86..c1e318491f27 100644 --- a/model/pipeline/src/main/proto/beam_runner_api.proto +++ b/model/pipeline/src/main/proto/beam_runner_api.proto @@ -1043,11 +1043,8 @@ message StandardCoders { // - Followed by N interleaved keys and values, encoded with their // corresponding coder. // - // Nullable types in container types (ArrayType, MapType) are encoded by: - // - A one byte null indicator, 0x00 for null values, or 0x01 for present - // values. - // - For present values the null indicator is followed by the value - // encoded with it's corresponding coder. + // Nullable types in container types (ArrayType, MapType) follow the + // encoding described for general Nullable types below. 
// // Well known logical types: // beam:logical_type:micros_instant:v1 @@ -1085,6 +1082,15 @@ message StandardCoders { // Components: the user key coder. // Experimental. SHARDED_KEY = 15 [(beam_urn) = "beam:coder:sharded_key:v1"]; + + // Wraps a coder of a potentially null value. + // A Nullable Type is encoded by: + // - A one byte null indicator, 0x00 for null values, or 0x01 for present + // values. + // - For present values the null indicator is followed by the value + // encoded with its corresponding coder. + // Components: single coder for the value. + NULLABLE = 17 [(beam_urn) = "beam:coder:nullable:v1"]; } } diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java index 1838fa692e1b..59d14b608621 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/CoderTranslators.java @@ -27,6 +27,7 @@ import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.LengthPrefixCoder; +import org.apache.beam.sdk.coders.NullableCoder; import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder; import org.apache.beam.sdk.schemas.Schema; @@ -204,6 +205,22 @@ public List> getComponents(TimestampPrefixingWindowCoder f }; } + static CoderTranslator> nullable() { + return new SimpleStructuredCoderTranslator>() { + @Override + protected NullableCoder fromComponents(List> components) { + checkArgument( + components.size() == 1, "Expected one component, but received: " + components); + return NullableCoder.of(components.get(0)); + } + + @Override + public List> getComponents(NullableCoder from) { + return from.getComponents(); + } + }; + } + 
public abstract static class SimpleStructuredCoderTranslator> implements CoderTranslator { @Override diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java index e0cc8dc11b94..eb94476d2b7d 100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoderRegistrar.java @@ -29,6 +29,7 @@ import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.LengthPrefixCoder; +import org.apache.beam.sdk.coders.NullableCoder; import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder; @@ -73,6 +74,7 @@ public class ModelCoderRegistrar implements CoderTranslatorRegistrar { .put(RowCoder.class, ModelCoders.ROW_CODER_URN) .put(ShardedKey.Coder.class, ModelCoders.SHARDED_KEY_CODER_URN) .put(TimestampPrefixingWindowCoder.class, ModelCoders.CUSTOM_WINDOW_CODER_URN) + .put(NullableCoder.class, ModelCoders.NULLABLE_CODER_URN) .build(); private static final Map, CoderTranslator> @@ -96,6 +98,7 @@ public class ModelCoderRegistrar implements CoderTranslatorRegistrar { .put(RowCoder.class, CoderTranslators.row()) .put(ShardedKey.Coder.class, CoderTranslators.shardedKey()) .put(TimestampPrefixingWindowCoder.class, CoderTranslators.timestampPrefixingWindow()) + .put(NullableCoder.class, CoderTranslators.nullable()) .build(); static { diff --git a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java index b616ffab462e..bc0ec755f4cc 
100644 --- a/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java +++ b/runners/core-construction-java/src/main/java/org/apache/beam/runners/core/construction/ModelCoders.java @@ -67,6 +67,8 @@ private ModelCoders() {} public static final String SHARDED_KEY_CODER_URN = getUrn(StandardCoders.Enum.SHARDED_KEY); + public static final String NULLABLE_CODER_URN = getUrn(StandardCoders.Enum.NULLABLE); + static { checkState( STATE_BACKED_ITERABLE_CODER_URN.equals(getUrn(StandardCoders.Enum.STATE_BACKED_ITERABLE))); @@ -90,7 +92,8 @@ private ModelCoders() {} ROW_CODER_URN, PARAM_WINDOWED_VALUE_CODER_URN, STATE_BACKED_ITERABLE_CODER_URN, - SHARDED_KEY_CODER_URN); + SHARDED_KEY_CODER_URN, + NULLABLE_CODER_URN); public static Set urns() { return MODEL_CODER_URNS; diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java index 544f43d0f0f4..f759ebede63c 100644 --- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java +++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/CoderTranslationTest.java @@ -42,6 +42,7 @@ import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.KvCoder; import org.apache.beam.sdk.coders.LengthPrefixCoder; +import org.apache.beam.sdk.coders.NullableCoder; import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.coders.SerializableCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; @@ -97,6 +98,7 @@ public class CoderTranslationTest { Field.of("bar", FieldType.logicalType(FixedBytes.of(123)))))) .add(ShardedKey.Coder.of(StringUtf8Coder.of())) .add(TimestampPrefixingWindowCoder.of(IntervalWindowCoder.of())) + .add(NullableCoder.of(ByteArrayCoder.of())) .build(); /** diff --git 
a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/wire/CommonCoderTest.java b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/wire/CommonCoderTest.java index c5f2283e41b0..ca5274a358bf 100644 --- a/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/wire/CommonCoderTest.java +++ b/runners/java-fn-execution/src/test/java/org/apache/beam/runners/fnexecution/wire/CommonCoderTest.java @@ -69,6 +69,7 @@ import org.apache.beam.sdk.coders.IterableCoder; import org.apache.beam.sdk.coders.IterableLikeCoder; import org.apache.beam.sdk.coders.KvCoder; +import org.apache.beam.sdk.coders.NullableCoder; import org.apache.beam.sdk.coders.RowCoder; import org.apache.beam.sdk.coders.StringUtf8Coder; import org.apache.beam.sdk.coders.TimestampPrefixingWindowCoder; @@ -134,6 +135,7 @@ public class CommonCoderTest { .put(getUrn(StandardCoders.Enum.SHARDED_KEY), ShardedKey.Coder.class) .put(getUrn(StandardCoders.Enum.CUSTOM_WINDOW), TimestampPrefixingWindowCoder.class) .put(getUrn(StandardCoders.Enum.STATE_BACKED_ITERABLE), StateBackedIterable.Coder.class) + .put(getUrn(StandardCoders.Enum.NULLABLE), NullableCoder.class) .build(); @AutoValue @@ -201,7 +203,7 @@ abstract static class OneCoderTestSpec { @SuppressWarnings("mutable") abstract byte[] getSerialized(); - abstract Object getValue(); + abstract @Nullable Object getValue(); static OneCoderTestSpec create( CommonCoder coder, boolean nested, byte[] serialized, Object value) { @@ -382,6 +384,17 @@ private static Object convertValue(Object value, CommonCoder coderSpec, Coder co Map kvMap = (Map) value; Coder windowCoder = ((TimestampPrefixingWindowCoder) coder).getWindowCoder(); return convertValue(kvMap.get("window"), coderSpec.getComponents().get(0), windowCoder); + } else if (s.equals(getUrn(StandardCoders.Enum.NULLABLE))) { + if (coderSpec.getComponents().size() == 1 + && 
coderSpec.getComponents().get(0).getUrn().equals(getUrn(StandardCoders.Enum.BYTES))) { + if (value == null) { + return null; + } else { + return ((String) value).getBytes(StandardCharsets.ISO_8859_1); + } + } else { + throw new IllegalStateException("Unknown or missing nested coder for nullable coder"); + } } else { throw new IllegalStateException("Unknown coder URN: " + coderSpec.getUrn()); } @@ -575,6 +588,8 @@ private void verifyDecodedValue(CommonCoder coder, Object expectedValue, Object assertEquals(expectedValue, actualValue); } else if (s.equals(getUrn(StandardCoders.Enum.CUSTOM_WINDOW))) { assertEquals(expectedValue, actualValue); + } else if (s.equals(getUrn(StandardCoders.Enum.NULLABLE))) { + assertThat(expectedValue, equalTo(actualValue)); } else { throw new IllegalStateException("Unknown coder URN: " + coder.getUrn()); } diff --git a/sdks/go/pkg/beam/core/graph/coder/coder.go b/sdks/go/pkg/beam/core/graph/coder/coder.go index 6eea66b0d317..8424f7af8758 100644 --- a/sdks/go/pkg/beam/core/graph/coder/coder.go +++ b/sdks/go/pkg/beam/core/graph/coder/coder.go @@ -169,6 +169,7 @@ const ( VarInt Kind = "varint" Double Kind = "double" Row Kind = "R" + Nullable Kind = "N" Timer Kind = "T" PaneInfo Kind = "PI" WindowedValue Kind = "W" @@ -198,7 +199,7 @@ type Coder struct { Kind Kind T typex.FullType - Components []*Coder // WindowedValue, KV, CoGBK + Components []*Coder // WindowedValue, KV, CoGBK, Nullable Custom *CustomCoder // Custom Window *WindowCoder // WindowedValue @@ -260,7 +261,7 @@ func (c *Coder) String() string { switch c.Kind { case WindowedValue, ParamWindowedValue, Window, Timer: ret += fmt.Sprintf("!%v", c.Window) - case KV, CoGBK, Bytes, Bool, VarInt, Double, String, LP: // No additional info. + case KV, CoGBK, Bytes, Bool, VarInt, Double, String, LP, Nullable: // No additional info. 
default: ret += fmt.Sprintf("[%v]", c.T) } @@ -394,6 +395,20 @@ func NewKV(components []*Coder) *Coder { } } +func NewN(component *Coder) *Coder { + coders := []*Coder{component} + checkCodersNotNil(coders) + return &Coder{ + Kind: Nullable, + T: typex.New(typex.NullableType, component.T), + Components: coders, + } +} + +func IsNullable(c *Coder) bool { + return c.Kind == Nullable +} + // IsCoGBK returns true iff the coder is for a CoGBK type. func IsCoGBK(c *Coder) bool { return c.Kind == CoGBK diff --git a/sdks/go/pkg/beam/core/graph/coder/coder_test.go b/sdks/go/pkg/beam/core/graph/coder/coder_test.go index 762ed848f589..44606dc1efb2 100644 --- a/sdks/go/pkg/beam/core/graph/coder/coder_test.go +++ b/sdks/go/pkg/beam/core/graph/coder/coder_test.go @@ -168,6 +168,9 @@ func TestCoder_String(t *testing.T) { }, { want: "KV", c: NewKV([]*Coder{bytes, ints}), + }, { + want: "N", + c: NewN(bytes), }, { want: "CoGBK", c: NewCoGBK([]*Coder{bytes, ints, bytes}), @@ -277,6 +280,10 @@ func TestCoder_Equals(t *testing.T) { want: true, a: NewKV([]*Coder{custom1, ints}), b: NewKV([]*Coder{customSame, ints}), + }, { + want: true, + a: NewN(custom1), + b: NewN(customSame), }, { want: true, a: NewCoGBK([]*Coder{custom1, ints, customSame}), @@ -517,6 +524,60 @@ func TestNewKV(t *testing.T) { } } +func TestNewNullable(t *testing.T) { + bytes := NewBytes() + + tests := []struct { + name string + component *Coder + shouldpanic bool + want *Coder + }{ + { + name: "nil", + component: nil, + shouldpanic: true, + }, + { + name: "empty", + component: &Coder{}, + shouldpanic: true, + }, + { + name: "bytes", + component: bytes, + shouldpanic: false, + want: &Coder{ + Kind: Nullable, + T: typex.New(typex.NullableType, bytes.T), + Components: []*Coder{bytes}, + }, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + if test.shouldpanic { + defer func() { + if p := recover(); p != nil { + t.Log(p) + return + } + t.Fatalf("NewNullable(%v): want 
panic", test.component) + }() + } + got := NewN(test.component) + if !IsNullable(got) { + t.Errorf("IsNullable(%v) = false, want true", got) + } + if test.want != nil && !test.want.Equals(got) { + t.Fatalf("NewNullable(%v) = %v, want %v", test.component, got, test.want) + } + }) + } +} + func TestNewCoGBK(t *testing.T) { bytes := NewBytes() ints := NewVarInt() diff --git a/sdks/go/pkg/beam/core/graph/coder/map.go b/sdks/go/pkg/beam/core/graph/coder/map.go index 2d72446bf444..30eee7b008a0 100644 --- a/sdks/go/pkg/beam/core/graph/coder/map.go +++ b/sdks/go/pkg/beam/core/graph/coder/map.go @@ -62,24 +62,6 @@ func mapDecoder(rt reflect.Type, decodeToKey, decodeToElem typeDecoderFieldRefle } } -// containerNilDecoder handles when a value is nillable for map or iterable components. -// Nillable types have an extra byte prefixing them indicating nil status. -func containerNilDecoder(decodeToElem func(reflect.Value, io.Reader) error) func(reflect.Value, io.Reader) error { - return func(ret reflect.Value, r io.Reader) error { - hasValue, err := DecodeBool(r) - if err != nil { - return err - } - if !hasValue { - return nil - } - if err := decodeToElem(ret, r); err != nil { - return err - } - return nil - } -} - // mapEncoder reflectively encodes a map or array type using the beam map encoding. func mapEncoder(rt reflect.Type, encodeKey, encodeValue typeEncoderFieldReflect) func(reflect.Value, io.Writer) error { return func(rv reflect.Value, w io.Writer) error { @@ -132,17 +114,3 @@ func mapEncoder(rt reflect.Type, encodeKey, encodeValue typeEncoderFieldReflect) return nil } } - -// containerNilEncoder handles when a value is nillable for map or iterable components. -// Nillable types have an extra byte prefixing them indicating nil status. 
-func containerNilEncoder(encodeElem func(reflect.Value, io.Writer) error) func(reflect.Value, io.Writer) error { - return func(rv reflect.Value, w io.Writer) error { - if rv.IsNil() { - return EncodeBool(false, w) - } - if err := EncodeBool(true, w); err != nil { - return err - } - return encodeElem(rv, w) - } -} diff --git a/sdks/go/pkg/beam/core/graph/coder/map_test.go b/sdks/go/pkg/beam/core/graph/coder/map_test.go index ee4c35afa609..3291f7fbd5a2 100644 --- a/sdks/go/pkg/beam/core/graph/coder/map_test.go +++ b/sdks/go/pkg/beam/core/graph/coder/map_test.go @@ -38,8 +38,8 @@ func TestEncodeDecodeMap(t *testing.T) { v.Set(reflect.New(reflectx.Uint8)) return byteDec(v.Elem(), r) } - byteCtnrPtrEnc := containerNilEncoder(bytePtrEnc) - byteCtnrPtrDec := containerNilDecoder(bytePtrDec) + byteCtnrPtrEnc := NullableEncoder(bytePtrEnc) + byteCtnrPtrDec := NullableDecoder(bytePtrDec) ptrByte := byte(42) diff --git a/sdks/go/pkg/beam/core/graph/coder/nil.go b/sdks/go/pkg/beam/core/graph/coder/nil.go new file mode 100644 index 000000000000..a7ed27cb6d5e --- /dev/null +++ b/sdks/go/pkg/beam/core/graph/coder/nil.go @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package coder + +import ( + "io" + "reflect" +) + +// NullableDecoder handles when a value is nillable. +// Nillable types have an extra byte prefixing them indicating nil status. +func NullableDecoder(decodeToElem func(reflect.Value, io.Reader) error) func(reflect.Value, io.Reader) error { + return func(ret reflect.Value, r io.Reader) error { + hasValue, err := DecodeBool(r) + if err != nil { + return err + } + if !hasValue { + return nil + } + if err := decodeToElem(ret, r); err != nil { + return err + } + return nil + } +} + +// NullableEncoder handles when a value is nillable. +// Nillable types have an extra byte prefixing them indicating nil status. +func NullableEncoder(encodeElem func(reflect.Value, io.Writer) error) func(reflect.Value, io.Writer) error { + return func(rv reflect.Value, w io.Writer) error { + if rv.IsNil() { + return EncodeBool(false, w) + } + if err := EncodeBool(true, w); err != nil { + return err + } + return encodeElem(rv, w) + } +} diff --git a/sdks/go/pkg/beam/core/graph/coder/nil_test.go b/sdks/go/pkg/beam/core/graph/coder/nil_test.go new file mode 100644 index 000000000000..89410b9c9395 --- /dev/null +++ b/sdks/go/pkg/beam/core/graph/coder/nil_test.go @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +package coder + +import ( + "bytes" + "fmt" + "io" + "reflect" + "testing" + + "github.com/apache/beam/sdks/v2/go/pkg/beam/core/util/reflectx" + "github.com/apache/beam/sdks/v2/go/pkg/beam/internal/errors" + "github.com/google/go-cmp/cmp" +) + +func TestEncodeDecodeNullable(t *testing.T) { + byteEnc := func(v reflect.Value, w io.Writer) error { + return EncodeByte(byte(v.Uint()), w) + } + byteDec := func(v reflect.Value, r io.Reader) error { + b, err := DecodeByte(r) + if err != nil { + return errors.Wrap(err, "error decoding single byte field") + } + v.SetUint(uint64(b)) + return nil + } + bytePtrEnc := func(v reflect.Value, w io.Writer) error { + return byteEnc(v.Elem(), w) + } + bytePtrDec := func(v reflect.Value, r io.Reader) error { + v.Set(reflect.New(reflectx.Uint8)) + return byteDec(v.Elem(), r) + } + byteCtnrPtrEnc := NullableEncoder(bytePtrEnc) + byteCtnrPtrDec := NullableDecoder(bytePtrDec) + + tests := []struct { + decoded interface{} + encoded []byte + }{ + { + decoded: (*byte)(nil), + encoded: []byte{0}, + }, + { + decoded: create(10), + encoded: []byte{1, 10}, + }, + { + decoded: create(20), + encoded: []byte{1, 20}, + }, + } + + for _, test := range tests { + t.Run(fmt.Sprintf("encode %q", test.encoded), func(t *testing.T) { + var buf bytes.Buffer + encErr := byteCtnrPtrEnc(reflect.ValueOf(test.decoded), &buf) + if encErr != nil { + t.Fatalf("NullableEncoder(%q) = %v", test.decoded, encErr) + } + if d := cmp.Diff(test.encoded, buf.Bytes()); d != "" { + t.Errorf("NullableEncoder(%q) = %v, want %v diff(-want,+got):\n %v", test.decoded, buf.Bytes(), test.encoded, d) + } + }) + t.Run(fmt.Sprintf("decode %q", test.decoded), func(t *testing.T) { + buf := bytes.NewBuffer(test.encoded) + rv := reflect.New(reflect.TypeOf(test.decoded)).Elem() + decErr := byteCtnrPtrDec(rv, buf) + if decErr != nil { + t.Fatalf("NullableDecoder(%q) = %v", test.encoded, 
decErr) + } + if d := cmp.Diff(test.decoded, rv.Interface()); d != "" { + t.Errorf("NullableDecoder (%q) = %q, want %v diff(-want,+got):\n %v", test.encoded, rv.Interface(), test.decoded, d) + } + }) + } + +} + +func create(x byte) *byte { + return &x +} diff --git a/sdks/go/pkg/beam/core/graph/coder/row_decoder.go b/sdks/go/pkg/beam/core/graph/coder/row_decoder.go index 1e1fcad32154..9688ed9876c4 100644 --- a/sdks/go/pkg/beam/core/graph/coder/row_decoder.go +++ b/sdks/go/pkg/beam/core/graph/coder/row_decoder.go @@ -386,7 +386,7 @@ func (b *RowDecoderBuilder) containerDecoderForType(t reflect.Type) (typeDecoder return typeDecoderFieldReflect{}, err } if t.Kind() == reflect.Ptr { - return typeDecoderFieldReflect{decode: containerNilDecoder(dec.decode), addr: dec.addr}, nil + return typeDecoderFieldReflect{decode: NullableDecoder(dec.decode), addr: dec.addr}, nil } return dec, nil } diff --git a/sdks/go/pkg/beam/core/graph/coder/row_encoder.go b/sdks/go/pkg/beam/core/graph/coder/row_encoder.go index e12776459da8..cfc1a8e51a3d 100644 --- a/sdks/go/pkg/beam/core/graph/coder/row_encoder.go +++ b/sdks/go/pkg/beam/core/graph/coder/row_encoder.go @@ -262,7 +262,7 @@ func (b *RowEncoderBuilder) containerEncoderForType(t reflect.Type) (typeEncoder return typeEncoderFieldReflect{}, err } if t.Kind() == reflect.Ptr { - return typeEncoderFieldReflect{encode: containerNilEncoder(encf.encode), addr: encf.addr}, nil + return typeEncoderFieldReflect{encode: NullableEncoder(encf.encode), addr: encf.addr}, nil } return encf, nil } diff --git a/sdks/go/pkg/beam/core/runtime/exec/coder.go b/sdks/go/pkg/beam/core/runtime/exec/coder.go index c7a19eae0470..145209a492cd 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/coder.go +++ b/sdks/go/pkg/beam/core/runtime/exec/coder.go @@ -156,6 +156,12 @@ func MakeElementEncoder(c *coder.Coder) ElementEncoder { enc: enc, } + case coder.Nullable: + return &nullableEncoder{ + inner: MakeElementEncoder(c.Components[0]), + be: boolEncoder{}, + } + 
default: panic(fmt.Sprintf("Unexpected coder: %v", c)) } @@ -267,6 +273,12 @@ func MakeElementDecoder(c *coder.Coder) ElementDecoder { dec: dec, } + case coder.Nullable: + return &nullableDecoder{ + inner: MakeElementDecoder(c.Components[0]), + bd: boolDecoder{}, + } + default: panic(fmt.Sprintf("Unexpected coder: %v", c)) } @@ -609,6 +621,56 @@ func convertIfNeeded(v interface{}, allocated *FullValue) *FullValue { return allocated } +type nullableEncoder struct { + inner ElementEncoder + be boolEncoder +} + +func (n *nullableEncoder) Encode(value *FullValue, writer io.Writer) error { + if value.Elm == nil { + if err := n.be.Encode(&FullValue{Elm: false}, writer); err != nil { + return err + } + return nil + } + if err := n.be.Encode(&FullValue{Elm: true}, writer); err != nil { + return err + } + if err := n.inner.Encode(value, writer); err != nil { + return err + } + return nil +} + +type nullableDecoder struct { + inner ElementDecoder + bd boolDecoder +} + +func (n *nullableDecoder) Decode(reader io.Reader) (*FullValue, error) { + hasValue, err := n.bd.Decode(reader) + if err != nil { + return nil, err + } + if !hasValue.Elm.(bool) { + return &FullValue{}, nil + } + val, err := n.inner.Decode(reader) + if err != nil { + return nil, err + } + return val, nil +} + +func (n *nullableDecoder) DecodeTo(reader io.Reader, value *FullValue) error { + val, err := n.Decode(reader) + if err != nil { + return err + } + value.Elm = val.Elm + return nil +} + type iterableEncoder struct { t reflect.Type enc ElementEncoder diff --git a/sdks/go/pkg/beam/core/runtime/exec/coder_test.go b/sdks/go/pkg/beam/core/runtime/exec/coder_test.go index 02d1f81da7e0..7ee13ecaf9ee 100644 --- a/sdks/go/pkg/beam/core/runtime/exec/coder_test.go +++ b/sdks/go/pkg/beam/core/runtime/exec/coder_test.go @@ -80,6 +80,12 @@ func TestCoders(t *testing.T) { }, { coder: coder.NewPW(coder.NewString(), coder.NewGlobalWindow()), val: &FullValue{Elm: "myString" /*Windowing info isn't encoded for PW so we can 
omit it here*/}, + }, { + coder: coder.NewN(coder.NewBytes()), + val: &FullValue{}, + }, { + coder: coder.NewN(coder.NewBytes()), + val: &FullValue{Elm: []byte("myBytes")}, }, } { t.Run(fmt.Sprintf("%v", test.coder), func(t *testing.T) { diff --git a/sdks/go/pkg/beam/core/runtime/graphx/coder.go b/sdks/go/pkg/beam/core/runtime/graphx/coder.go index a8b897538232..e70567d705bd 100644 --- a/sdks/go/pkg/beam/core/runtime/graphx/coder.go +++ b/sdks/go/pkg/beam/core/runtime/graphx/coder.go @@ -46,6 +46,7 @@ const ( urnParamWindowedValueCoder = "beam:coder:param_windowed_value:v1" urnTimerCoder = "beam:coder:timer:v1" urnRowCoder = "beam:coder:row:v1" + urnNullableCoder = "beam:coder:nullable:v1" urnGlobalWindow = "beam:coder:global_window:v1" urnIntervalWindow = "beam:coder:interval_window:v1" @@ -71,6 +72,7 @@ func knownStandardCoders() []string { urnGlobalWindow, urnIntervalWindow, urnRowCoder, + urnNullableCoder, // TODO(BEAM-10660): Add urnTimerCoder once finalized. } } @@ -368,6 +370,15 @@ func (b *CoderUnmarshaller) makeCoder(id string, c *pipepb.Coder) (*coder.Coder, return nil, err } return coder.NewR(typex.New(t)), nil + case urnNullableCoder: + if len(components) != 1 { + return nil, errors.Errorf("could not unmarshal nullable coder from %v, expected one component but got %d", c, len(components)) + } + elm, err := b.Coder(components[0]) + if err != nil { + return nil, err + } + return coder.NewN(elm), nil // Special handling for window coders so they can be treated as // a general coder. 
Generally window coders are not used outside of @@ -386,7 +397,6 @@ func (b *CoderUnmarshaller) makeCoder(id string, c *pipepb.Coder) (*coder.Coder, return nil, err } return &coder.Coder{Kind: coder.Window, T: typex.New(reflect.TypeOf((*struct{})(nil)).Elem()), Window: w}, nil - default: return nil, errors.Errorf("could not unmarshal coder from %v, unknown URN %v", c, urn) } @@ -465,6 +475,13 @@ func (b *CoderMarshaller) Add(c *coder.Coder) (string, error) { } return b.internBuiltInCoder(urnKVCoder, comp...), nil + case coder.Nullable: + comp, err := b.AddMulti(c.Components) + if err != nil { + return "", errors.Wrapf(err, "failed to marshal Nullable coder %v", c) + } + return b.internBuiltInCoder(urnNullableCoder, comp...), nil + case coder.CoGBK: comp, err := b.AddMulti(c.Components) if err != nil { diff --git a/sdks/go/pkg/beam/core/runtime/graphx/coder_test.go b/sdks/go/pkg/beam/core/runtime/graphx/coder_test.go index 8296c9fd3381..aad15df0f23f 100644 --- a/sdks/go/pkg/beam/core/runtime/graphx/coder_test.go +++ b/sdks/go/pkg/beam/core/runtime/graphx/coder_test.go @@ -88,6 +88,10 @@ func TestMarshalUnmarshalCoders(t *testing.T) { "W", coder.NewW(coder.NewBytes(), coder.NewGlobalWindow()), }, + { + "N", + coder.NewN(coder.NewBytes()), + }, { "KV", coder.NewKV([]*coder.Coder{foo, bar}), diff --git a/sdks/go/pkg/beam/core/runtime/graphx/dataflow.go b/sdks/go/pkg/beam/core/runtime/graphx/dataflow.go index 77aa6ca46a57..e2eec3b5bcc8 100644 --- a/sdks/go/pkg/beam/core/runtime/graphx/dataflow.go +++ b/sdks/go/pkg/beam/core/runtime/graphx/dataflow.go @@ -48,6 +48,7 @@ const ( doubleType = "kind:double" streamType = "kind:stream" pairType = "kind:pair" + nullableType = "kind:nullable" lengthPrefixType = "kind:length_prefix" rowType = "kind:row" @@ -117,6 +118,16 @@ func EncodeCoderRef(c *coder.Coder) (*CoderRef, error) { } return &CoderRef{Type: pairType, Components: []*CoderRef{key, value}, IsPairLike: true}, nil + case coder.Nullable: + if len(c.Components) != 1 { + 
return nil, errors.Errorf("bad N: %v", c) + } + innerref, err := EncodeCoderRef(c.Components[0]) + if err != nil { + return nil, err + } + return &CoderRef{Type: nullableType, Components: []*CoderRef{innerref}}, nil + case coder.CoGBK: if len(c.Components) < 2 { return nil, errors.Errorf("bad CoGBK: %v", c) @@ -264,6 +275,19 @@ func DecodeCoderRef(c *CoderRef) (*coder.Coder, error) { t := typex.New(root, key.T, value.T) return &coder.Coder{Kind: kind, T: t, Components: []*coder.Coder{key, value}}, nil + case nullableType: + if len(c.Components) != 1 { + return nil, errors.Errorf("bad nullable: %+v", c) + } + + inner, err := DecodeCoderRef(c.Components[0]) + if err != nil { + return nil, err + } + + t := typex.New(typex.NullableType, inner.T) + return &coder.Coder{Kind: coder.Nullable, T: t, Components: []*coder.Coder{inner}}, nil + case lengthPrefixType: if len(c.Components) != 1 { return nil, errors.Errorf("bad length prefix: %+v", c) diff --git a/sdks/go/pkg/beam/core/typex/fulltype.go b/sdks/go/pkg/beam/core/typex/fulltype.go index cbfb443755ba..df5425a4e1a9 100644 --- a/sdks/go/pkg/beam/core/typex/fulltype.go +++ b/sdks/go/pkg/beam/core/typex/fulltype.go @@ -87,6 +87,8 @@ func printShortComposite(t reflect.Type) string { return "CoGBK" case KVType: return "KV" + case NullableType: + return "Nullable" default: return fmt.Sprintf("invalid(%v)", t) } diff --git a/sdks/go/pkg/beam/core/typex/special.go b/sdks/go/pkg/beam/core/typex/special.go index d13aab562a9a..b45cb61081bc 100644 --- a/sdks/go/pkg/beam/core/typex/special.go +++ b/sdks/go/pkg/beam/core/typex/special.go @@ -38,9 +38,10 @@ var ( WindowType = reflect.TypeOf((*Window)(nil)).Elem() PaneInfoType = reflect.TypeOf((*PaneInfo)(nil)).Elem() - KVType = reflect.TypeOf((*KV)(nil)).Elem() - CoGBKType = reflect.TypeOf((*CoGBK)(nil)).Elem() - WindowedValueType = reflect.TypeOf((*WindowedValue)(nil)).Elem() + KVType = reflect.TypeOf((*KV)(nil)).Elem() + NullableType = reflect.TypeOf((*Nullable)(nil)).Elem() + 
CoGBKType = reflect.TypeOf((*CoGBK)(nil)).Elem() + WindowedValueType = reflect.TypeOf((*WindowedValue)(nil)).Elem() BundleFinalizationType = reflect.TypeOf((*BundleFinalization)(nil)).Elem() ) @@ -92,6 +93,8 @@ type PaneInfo struct { type KV struct{} +type Nullable struct{} + type CoGBK struct{} type WindowedValue struct{} diff --git a/sdks/go/test/regression/coders/fromyaml/fromyaml.go b/sdks/go/test/regression/coders/fromyaml/fromyaml.go index 82d3e9fdb248..199ff4e2a91d 100644 --- a/sdks/go/test/regression/coders/fromyaml/fromyaml.go +++ b/sdks/go/test/regression/coders/fromyaml/fromyaml.go @@ -218,6 +218,19 @@ func diff(c Coder, elem *exec.FullValue, eg yaml.MapItem) bool { } return pass + case "beam:coder:nullable:v1": + if elem.Elm == nil || eg.Value == nil { + got, want = elem.Elm, eg.Value + } else { + got = string(elem.Elm.([]byte)) + switch egv := eg.Value.(type) { + case string: + want = egv + case []byte: + want = string(egv) + } + } + case "beam:coder:iterable:v1": pass := true gotrv := reflect.ValueOf(elem.Elm) diff --git a/sdks/python/apache_beam/coders/coders.py b/sdks/python/apache_beam/coders/coders.py index 63add754d0f7..fce397df6267 100644 --- a/sdks/python/apache_beam/coders/coders.py +++ b/sdks/python/apache_beam/coders/coders.py @@ -607,6 +607,22 @@ def _create_impl(self): def to_type_hint(self): return typehints.Optional[self._value_coder.to_type_hint()] + def _get_component_coders(self): + # type: () -> List[Coder] + return [self._value_coder] + + @classmethod + def from_type_hint(cls, typehint, registry): + if typehints.is_nullable(typehint): + return cls( + registry.get_coder( + typehints.get_concrete_type_from_nullable(typehint))) + else: + raise TypeError( + 'Typehint is not of nullable type, ' + 'and cannot be converted to a NullableCoder', + typehint) + def is_deterministic(self): # type: () -> bool return self._value_coder.is_deterministic() @@ -619,6 +635,9 @@ def __hash__(self): return hash(type(self)) + hash(self._value_coder) 
+Coder.register_structured_urn(common_urns.coders.NULLABLE.urn, NullableCoder) + + class VarIntCoder(FastCoder): """Variable-length integer coder.""" def _create_impl(self): @@ -1524,7 +1543,6 @@ def __hash__(self): class StateBackedIterableCoder(FastCoder): - DEFAULT_WRITE_THRESHOLD = 1 def __init__( diff --git a/sdks/python/apache_beam/coders/standard_coders_test.py b/sdks/python/apache_beam/coders/standard_coders_test.py index aa925a3146fb..e25e232597e3 100644 --- a/sdks/python/apache_beam/coders/standard_coders_test.py +++ b/sdks/python/apache_beam/coders/standard_coders_test.py @@ -193,7 +193,9 @@ class StandardCodersTest(unittest.TestCase): value_parser: ShardedKey( key=value_parser(x['key']), shard_id=x['shardId'].encode('utf-8')), 'beam:coder:custom_window:v1': lambda x, - window_parser: window_parser(x['window']) + window_parser: window_parser(x['window']), + 'beam:coder:nullable:v1': lambda x, + value_parser: x.encode('utf-8') if x else None } def test_standard_coders(self): diff --git a/sdks/python/apache_beam/coders/typecoders.py b/sdks/python/apache_beam/coders/typecoders.py index 3bf2d15e7874..a66ebe523697 100644 --- a/sdks/python/apache_beam/coders/typecoders.py +++ b/sdks/python/apache_beam/coders/typecoders.py @@ -65,7 +65,6 @@ def MakeXyzs(v): """ # pytype: skip-file - from typing import Any from typing import Dict from typing import Iterable @@ -138,6 +137,8 @@ def get_coder(self, typehint): return coders.IterableCoder.from_type_hint(typehint, self) elif isinstance(typehint, typehints.ListConstraint): return coders.ListCoder.from_type_hint(typehint, self) + elif typehints.is_nullable(typehint): + return coders.NullableCoder.from_type_hint(typehint, self) elif typehint is None: # In some old code, None is used for Any. # TODO(robertwb): Clean this up. 
diff --git a/sdks/python/apache_beam/coders/typecoders_test.py b/sdks/python/apache_beam/coders/typecoders_test.py index 02f4565c5e2d..f74483ad48dc 100644 --- a/sdks/python/apache_beam/coders/typecoders_test.py +++ b/sdks/python/apache_beam/coders/typecoders_test.py @@ -140,6 +140,13 @@ def test_list_coder(self): self.assertIs( list, type(expected_coder.decode(expected_coder.encode(values)))) + def test_nullable_coder(self): + expected_coder = coders.NullableCoder(coders.BytesCoder()) + real_coder = typecoders.registry.get_coder(typehints.Optional(bytes)) + self.assertEqual(expected_coder, real_coder) + self.assertEqual(expected_coder.encode(None), real_coder.encode(None)) + self.assertEqual(expected_coder.encode(b'abc'), real_coder.encode(b'abc')) + if __name__ == '__main__': unittest.main() diff --git a/sdks/python/apache_beam/typehints/typehints.py b/sdks/python/apache_beam/typehints/typehints.py index 9b4fc95a7d90..45c2366dd8b8 100644 --- a/sdks/python/apache_beam/typehints/typehints.py +++ b/sdks/python/apache_beam/typehints/typehints.py @@ -503,10 +503,13 @@ def __repr__(self): return 'Union[%s]' % ( ', '.join(sorted(_unified_repr(t) for t in self.union_types))) - def _inner_types(self): + def inner_types(self): for t in self.union_types: yield t + def contains_type(self, maybe_type): + return maybe_type in self.union_types + def _consistent_with_check_(self, sub): if isinstance(sub, UnionConstraint): # A union type is compatible if every possible type is compatible. 
@@ -601,6 +604,22 @@ def __getitem__(self, py_type): return Union[py_type, type(None)] +def is_nullable(typehint): + return ( + isinstance(typehint, UnionConstraint) and + typehint.contains_type(type(None)) and + len(list(typehint.inner_types())) == 2) + + +def get_concrete_type_from_nullable(typehint): + if is_nullable(typehint): + for inner_type in typehint.inner_types(): + if not type(None) == inner_type: + return inner_type + else: + raise TypeError('Typehint is not of nullable type', typehint) + + class TupleHint(CompositeTypeHint): """A Tuple type-hint. diff --git a/sdks/python/apache_beam/typehints/typehints_test.py b/sdks/python/apache_beam/typehints/typehints_test.py index b3a4d636e9b5..8818639035ff 100644 --- a/sdks/python/apache_beam/typehints/typehints_test.py +++ b/sdks/python/apache_beam/typehints/typehints_test.py @@ -320,6 +320,14 @@ def test_getitem_proxy_to_union(self): hint = typehints.Optional[int] self.assertTrue(isinstance(hint, typehints.UnionHint.UnionConstraint)) + def test_is_optional(self): + hint1 = typehints.Optional[int] + self.assertTrue(typehints.is_nullable(hint1)) + hint2 = typehints.UnionConstraint({int, bytes}) + self.assertFalse(typehints.is_nullable(hint2)) + hint3 = typehints.UnionConstraint({int, bytes, type(None)}) + self.assertFalse(typehints.is_nullable(hint3)) + class TupleHintTestCase(TypeHintTestCase): def test_getitem_invalid_ellipsis_type_param(self):