Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,17 @@ public static <T> void coderDecodeEncodeEqualInContext(
assertThat(decodeEncode(coder, context, value), equalTo(value));
}

/**
 * Runs {@code value} of type {@code T} through the given {@code Coder<T>} in the given {@code
 * Coder.Context} (encoding followed by decoding) and asserts that the supplied matcher accepts
 * the resulting value.
 */
public static <T> void coderDecodeEncodeInContext(
    Coder<T> coder, Coder.Context context, T value, org.hamcrest.Matcher<T> matcher)
    throws Exception {
  // Name the round-tripped value so the assertion reads as "result matches".
  T roundTripped = decodeEncode(coder, context, value);
  assertThat(roundTripped, matcher);
}

/**
* Verifies that for the given {@code Coder<Collection<T>>}, and value of type {@code
* Collection<T>}, encoding followed by decoding yields an equal value of type {@code
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.protobuf;

import com.google.protobuf.Descriptors;
import com.google.protobuf.DynamicMessage;
import com.google.protobuf.Message;
import com.google.protobuf.Parser;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderProvider;
import org.apache.beam.sdk.coders.CoderRegistry;
import org.apache.beam.sdk.coders.DefaultCoder;
import org.apache.beam.sdk.values.TypeDescriptor;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Sets;

/**
* A {@link Coder} using Google Protocol Buffers binary format. {@link DynamicProtoCoder} supports
* both Protocol Buffers syntax versions 2 and 3.
*
* <p>To learn more about Protocol Buffers, visit: <a
* href="https://developers.google.com/protocol-buffers">https://developers.google.com/protocol-buffers</a>
*
* <p>{@link DynamicProtoCoder} is not registered in the global {@link CoderRegistry} as the
* descriptor is required to create the coder.
*/
public class DynamicProtoCoder extends ProtoCoder<DynamicMessage> {

public static final long serialVersionUID = 1L;

/**
* Returns a {@link DynamicProtoCoder} for the Protocol Buffers {@link DynamicMessage} for the
* given {@link Descriptors.Descriptor}.
*/
public static DynamicProtoCoder of(Descriptors.Descriptor protoMessageDescriptor) {
return new DynamicProtoCoder(
ProtoDomain.buildFrom(protoMessageDescriptor),
protoMessageDescriptor.getFullName(),
ImmutableSet.of());
}

/**
* Returns a {@link DynamicProtoCoder} for the Protocol Buffers {@link DynamicMessage} for the
* given {@link Descriptors.Descriptor}. The message descriptor should be part of the provided
* {@link ProtoDomain}, this will ensure object equality within messages from the same domain.
*/
public static DynamicProtoCoder of(
ProtoDomain domain, Descriptors.Descriptor protoMessageDescriptor) {
return new DynamicProtoCoder(domain, protoMessageDescriptor.getFullName(), ImmutableSet.of());
}

/**
* Returns a {@link DynamicProtoCoder} for the Protocol Buffers {@link DynamicMessage} for the
* given message name in a {@link ProtoDomain}. The message descriptor should be part of the
* provided * {@link ProtoDomain}, this will ensure object equality within messages from the same
* domain.
*/
public static DynamicProtoCoder of(ProtoDomain domain, String messageName) {
return new DynamicProtoCoder(domain, messageName, ImmutableSet.of());
}

/**
* Returns a {@link DynamicProtoCoder} like this one, but with the extensions from the given
* classes registered.
*
* <p>Each of the extension host classes must be an class automatically generated by the Protocol
* Buffers compiler, {@code protoc}, that contains messages.
*
* <p>Does not modify this object.
*/
@Override
public DynamicProtoCoder withExtensionsFrom(Iterable<Class<?>> moreExtensionHosts) {
validateExtensions(moreExtensionHosts);
return new DynamicProtoCoder(
this.domain,
this.messageName,
new ImmutableSet.Builder<Class<?>>()
.addAll(extensionHostClasses)
.addAll(moreExtensionHosts)
.build());
}

@Override
public boolean equals(Object other) {
if (this == other) {
return true;
}
if (other == null || getClass() != other.getClass()) {
return false;
}
DynamicProtoCoder otherCoder = (DynamicProtoCoder) other;
return protoMessageClass.equals(otherCoder.protoMessageClass)
&& Sets.newHashSet(extensionHostClasses)
.equals(Sets.newHashSet(otherCoder.extensionHostClasses))
&& domain.equals(otherCoder.domain)
&& messageName.equals(otherCoder.messageName);
}

@Override
public int hashCode() {
return Objects.hash(protoMessageClass, extensionHostClasses, domain, messageName);
}

////////////////////////////////////////////////////////////////////////////////////
// Private implementation details below.

// Constants used to serialize and deserialize
private static final String PROTO_MESSAGE_CLASS = "dynamic_proto_message_class";
private static final String PROTO_EXTENSION_HOSTS = "dynamic_proto_extension_hosts";

// Descriptor used by DynamicMessage.
private transient ProtoDomain domain;
private transient String messageName;

private DynamicProtoCoder(
ProtoDomain domain, String messageName, Set<Class<?>> extensionHostClasses) {
super(DynamicMessage.class, extensionHostClasses);
this.domain = domain;
this.messageName = messageName;
}

private void writeObject(ObjectOutputStream oos) throws IOException {
oos.defaultWriteObject();
oos.writeObject(domain);
oos.writeObject(messageName);
}

private void readObject(ObjectInputStream ois) throws ClassNotFoundException, IOException {
ois.defaultReadObject();
this.domain = (ProtoDomain) ois.readObject();
this.messageName = (String) ois.readObject();
}

/** Get the memoized {@link Parser}, possibly initializing it lazily. */
@Override
protected Parser<DynamicMessage> getParser() {
if (memoizedParser == null) {
DynamicMessage protoMessageInstance =
DynamicMessage.newBuilder(domain.getDescriptor(messageName)).build();
memoizedParser = protoMessageInstance.getParserForType();
}
return memoizedParser;
}

/**
* Returns a {@link CoderProvider} which uses the {@link DynamicProtoCoder} for {@link Message
* proto messages}.
*
* <p>This method is invoked reflectively from {@link DefaultCoder}.
*/
public static CoderProvider getCoderProvider() {
return new ProtoCoderProvider();
}

static final TypeDescriptor<Message> MESSAGE_TYPE = new TypeDescriptor<Message>() {};

/** A {@link CoderProvider} for {@link Message proto messages}. */
private static class ProtoCoderProvider extends CoderProvider {

@Override
public <T> Coder<T> coderFor(
TypeDescriptor<T> typeDescriptor, List<? extends Coder<?>> componentCoders)
throws CannotProvideCoderException {
if (!typeDescriptor.isSubtypeOf(MESSAGE_TYPE)) {
throw new CannotProvideCoderException(
String.format(
"Cannot provide %s because %s is not a subclass of %s",
DynamicProtoCoder.class.getSimpleName(), typeDescriptor, Message.class.getName()));
}

@SuppressWarnings("unchecked")
TypeDescriptor<? extends Message> messageType =
(TypeDescriptor<? extends Message>) typeDescriptor;
try {
@SuppressWarnings("unchecked")
Coder<T> coder = (Coder<T>) DynamicProtoCoder.of(messageType);
return coder;
} catch (IllegalArgumentException e) {
throw new CannotProvideCoderException(e);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;

import com.google.protobuf.DynamicMessage;
import com.google.protobuf.ExtensionRegistry;
import com.google.protobuf.Message;
import com.google.protobuf.Parser;
Expand All @@ -32,8 +33,6 @@
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import org.apache.beam.sdk.coders.CannotProvideCoderException;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.CoderException;
Expand Down Expand Up @@ -107,6 +106,8 @@
*/
public class ProtoCoder<T extends Message> extends CustomCoder<T> {

public static final long serialVersionUID = -5043999806040629525L;

/** Returns a {@link ProtoCoder} for the given Protocol Buffers {@link Message}. */
public static <T extends Message> ProtoCoder<T> of(Class<T> protoMessageClass) {
return new ProtoCoder<>(protoMessageClass, ImmutableSet.of());
Expand All @@ -123,15 +124,11 @@ public static <T extends Message> ProtoCoder<T> of(TypeDescriptor<T> protoMessag
}

/**
* Returns a {@link ProtoCoder} like this one, but with the extensions from the given classes
* registered.
* Validate that all extensionHosts are able to be registered.
*
* <p>Each of the extension host classes must be an class automatically generated by the Protocol
* Buffers compiler, {@code protoc}, that contains messages.
*
* <p>Does not modify this object.
* @param moreExtensionHosts
*/
public ProtoCoder<T> withExtensionsFrom(Iterable<Class<?>> moreExtensionHosts) {
void validateExtensions(Iterable<Class<?>> moreExtensionHosts) {
for (Class<?> extensionHost : moreExtensionHosts) {
// Attempt to access the required method, to make sure it's present.
try {
Expand All @@ -146,7 +143,19 @@ public ProtoCoder<T> withExtensionsFrom(Iterable<Class<?>> moreExtensionHosts) {
e);
}
}
}

/**
* Returns a {@link ProtoCoder} like this one, but with the extensions from the given classes
* registered.
*
* <p>Each of the extension host classes must be a class automatically generated by the Protocol
* Buffers compiler, {@code protoc}, that contains messages.
*
* <p>Does not modify this object.
*/
public ProtoCoder<T> withExtensionsFrom(Iterable<Class<?>> moreExtensionHosts) {
validateExtensions(moreExtensionHosts);
return new ProtoCoder<>(
protoMessageClass,
new ImmutableSet.Builder<Class<?>>()
Expand Down Expand Up @@ -200,7 +209,7 @@ public boolean equals(Object other) {
if (this == other) {
return true;
}
if (!(other instanceof ProtoCoder)) {
if (other == null || getClass() != other.getClass()) {
return false;
}
ProtoCoder<?> otherCoder = (ProtoCoder<?>) other;
Expand Down Expand Up @@ -253,37 +262,43 @@ public ExtensionRegistry getExtensionRegistry() {
// Private implementation details below.

/** The {@link Message} type to be coded. */
private final Class<T> protoMessageClass;
final Class<T> protoMessageClass;

/**
* All extension host classes included in this {@link ProtoCoder}. The extensions from these
* classes will be included in the {@link ExtensionRegistry} used during encoding and decoding.
*/
private final Set<Class<?>> extensionHostClasses;
final Set<Class<?>> extensionHostClasses;

// Constants used to serialize and deserialize
private static final String PROTO_MESSAGE_CLASS = "proto_message_class";
private static final String PROTO_EXTENSION_HOSTS = "proto_extension_hosts";

// Transient fields that are lazy initialized and then memoized.
private transient ExtensionRegistry memoizedExtensionRegistry;
private transient Parser<T> memoizedParser;
transient Parser<T> memoizedParser;

/** Private constructor. */
private ProtoCoder(Class<T> protoMessageClass, Set<Class<?>> extensionHostClasses) {
protected ProtoCoder(Class<T> protoMessageClass, Set<Class<?>> extensionHostClasses) {
this.protoMessageClass = protoMessageClass;
this.extensionHostClasses = extensionHostClasses;
}

/** Get the memoized {@link Parser}, possibly initializing it lazily. */
private Parser<T> getParser() {
protected Parser<T> getParser() {
if (memoizedParser == null) {
try {
@SuppressWarnings("unchecked")
T protoMessageInstance = (T) protoMessageClass.getMethod("getDefaultInstance").invoke(null);
@SuppressWarnings("unchecked")
Parser<T> tParser = (Parser<T>) protoMessageInstance.getParserForType();
memoizedParser = tParser;
if (DynamicMessage.class.equals(protoMessageClass)) {
throw new IllegalArgumentException(
"DynamicMessage is not supported by the ProtoCoder, use the DynamicProtoCoder.");
} else {
@SuppressWarnings("unchecked")
T protoMessageInstance =
(T) protoMessageClass.getMethod("getDefaultInstance").invoke(null);
@SuppressWarnings("unchecked")
Parser<T> tParser = (Parser<T>) protoMessageInstance.getParserForType();
memoizedParser = tParser;
}
} catch (IllegalAccessException | InvocationTargetException | NoSuchMethodException e) {
throw new IllegalArgumentException(e);
}
Expand Down Expand Up @@ -329,12 +344,4 @@ public <T> Coder<T> coderFor(
}
}
}

/** Collects the class names of all extension host classes into a naturally ordered set. */
private SortedSet<String> getSortedExtensionClasses() {
  SortedSet<String> sortedNames = new TreeSet<>();
  extensionHostClasses.forEach(hostClass -> sortedNames.add(hostClass.getName()));
  return sortedNames;
}
}
Loading